feature: add context to state API

Signed-off-by: 1046102779 <seachen@tencent.com>
This commit is contained in:
1046102779 2022-10-21 17:25:14 +08:00
parent 0326f139c5
commit 5a367b401a
58 changed files with 846 additions and 822 deletions

View File

@ -14,6 +14,7 @@ limitations under the License.
package aerospike
import (
"context"
"encoding/json"
"errors"
"fmt"
@ -110,7 +111,7 @@ func (aspike *Aerospike) Features() []state.Feature {
}
// Set stores value for a key to Aerospike. It honors ETag (for concurrency) and consistency settings.
func (aspike *Aerospike) Set(req *state.SetRequest) error {
func (aspike *Aerospike) Set(ctx context.Context, req *state.SetRequest) error {
err := state.CheckRequestOptions(req.Options)
if err != nil {
return err
@ -162,7 +163,7 @@ func (aspike *Aerospike) Set(req *state.SetRequest) error {
}
// Get retrieves state from Aerospike with a key.
func (aspike *Aerospike) Get(req *state.GetRequest) (*state.GetResponse, error) {
func (aspike *Aerospike) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) {
asKey, err := as.NewKey(aspike.namespace, aspike.set, req.Key)
if err != nil {
return nil, err
@ -196,7 +197,7 @@ func (aspike *Aerospike) Get(req *state.GetRequest) (*state.GetResponse, error)
}
// Delete performs a delete operation.
func (aspike *Aerospike) Delete(req *state.DeleteRequest) error {
func (aspike *Aerospike) Delete(ctx context.Context, req *state.DeleteRequest) error {
err := state.CheckRequestOptions(req.Options)
if err != nil {
return err

View File

@ -14,6 +14,7 @@ limitations under the License.
package tablestore
import (
"context"
"encoding/json"
"github.com/agrea/ptr"
@ -68,7 +69,7 @@ func (s *AliCloudTableStore) Features() []state.Feature {
return s.features
}
func (s *AliCloudTableStore) Get(req *state.GetRequest) (*state.GetResponse, error) {
func (s *AliCloudTableStore) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) {
criteria := &tablestore.SingleRowQueryCriteria{
PrimaryKey: s.primaryKey(req.Key),
TableName: s.metadata.TableName,
@ -103,7 +104,7 @@ func (s *AliCloudTableStore) getResp(columns []*tablestore.AttributeColumn) *sta
return getResp
}
func (s *AliCloudTableStore) BulkGet(reqs []state.GetRequest) (bool, []state.BulkGetResponse, error) {
func (s *AliCloudTableStore) BulkGet(ctx context.Context, reqs []state.GetRequest) (bool, []state.BulkGetResponse, error) {
// "len == 0": empty request, directly return empty response
if len(reqs) == 0 {
return true, []state.BulkGetResponse{}, nil
@ -139,7 +140,7 @@ func (s *AliCloudTableStore) BulkGet(reqs []state.GetRequest) (bool, []state.Bul
return true, responseList, nil
}
func (s *AliCloudTableStore) Set(req *state.SetRequest) error {
func (s *AliCloudTableStore) Set(ctx context.Context, req *state.SetRequest) error {
change := s.updateRowChange(req)
request := &tablestore.UpdateRowRequest{
@ -183,7 +184,7 @@ func unmarshal(val interface{}) []byte {
return []byte(output)
}
func (s *AliCloudTableStore) Delete(req *state.DeleteRequest) error {
func (s *AliCloudTableStore) Delete(ctx context.Context, req *state.DeleteRequest) error {
change := s.deleteRowChange(req)
deleteRowReq := &tablestore.DeleteRowRequest{
@ -205,15 +206,15 @@ func (s *AliCloudTableStore) deleteRowChange(req *state.DeleteRequest) *tablesto
return change
}
func (s *AliCloudTableStore) BulkSet(reqs []state.SetRequest) error {
return s.batchWrite(reqs, nil)
func (s *AliCloudTableStore) BulkSet(ctx context.Context, reqs []state.SetRequest) error {
return s.batchWrite(ctx, reqs, nil)
}
func (s *AliCloudTableStore) BulkDelete(reqs []state.DeleteRequest) error {
return s.batchWrite(nil, reqs)
func (s *AliCloudTableStore) BulkDelete(ctx context.Context, reqs []state.DeleteRequest) error {
return s.batchWrite(ctx, nil, reqs)
}
func (s *AliCloudTableStore) batchWrite(setReqs []state.SetRequest, deleteReqs []state.DeleteRequest) error {
func (s *AliCloudTableStore) batchWrite(ctx context.Context, setReqs []state.SetRequest, deleteReqs []state.DeleteRequest) error {
bathReq := &tablestore.BatchWriteRowRequest{
IsAtomic: true,
}

View File

@ -14,6 +14,7 @@ limitations under the License.
package tablestore
import (
"context"
"testing"
"github.com/agrea/ptr"
@ -63,7 +64,7 @@ func TestReadAndWrite(t *testing.T) {
Value: "value of key",
ETag: ptr.String("the etag"),
}
err := store.Set(setReq)
err := store.Set(context.TODO(), setReq)
assert.Nil(t, err)
})
@ -71,7 +72,7 @@ func TestReadAndWrite(t *testing.T) {
getReq := &state.GetRequest{
Key: "theFirstKey",
}
resp, err := store.Get(getReq)
resp, err := store.Get(context.TODO(), getReq)
assert.Nil(t, err)
assert.NotNil(t, resp)
assert.Equal(t, "value of key", string(resp.Data))
@ -83,7 +84,7 @@ func TestReadAndWrite(t *testing.T) {
Value: "1234",
ETag: ptr.String("the etag"),
}
err := store.Set(setReq)
err := store.Set(context.TODO(), setReq)
assert.Nil(t, err)
})
@ -91,14 +92,14 @@ func TestReadAndWrite(t *testing.T) {
getReq := &state.GetRequest{
Key: "theSecondKey",
}
resp, err := store.Get(getReq)
resp, err := store.Get(context.TODO(), getReq)
assert.Nil(t, err)
assert.NotNil(t, resp)
assert.Equal(t, "1234", string(resp.Data))
})
t.Run("test BulkSet", func(t *testing.T) {
err := store.BulkSet([]state.SetRequest{{
err := store.BulkSet(context.TODO(), []state.SetRequest{{
Key: "theFirstKey",
Value: "666",
}, {
@ -110,7 +111,7 @@ func TestReadAndWrite(t *testing.T) {
})
t.Run("test BulkGet", func(t *testing.T) {
_, resp, err := store.BulkGet([]state.GetRequest{{
_, resp, err := store.BulkGet(context.TODO(), []state.GetRequest{{
Key: "theFirstKey",
}, {
Key: "theSecondKey",
@ -126,12 +127,12 @@ func TestReadAndWrite(t *testing.T) {
req := &state.DeleteRequest{
Key: "theFirstKey",
}
err := store.Delete(req)
err := store.Delete(context.TODO(), req)
assert.Nil(t, err)
})
t.Run("test BulkGet2", func(t *testing.T) {
_, resp, err := store.BulkGet([]state.GetRequest{{
_, resp, err := store.BulkGet(context.TODO(), []state.GetRequest{{
Key: "theFirstKey",
}, {
Key: "theSecondKey",

View File

@ -14,6 +14,7 @@ limitations under the License.
package dynamodb
import (
"context"
"crypto/rand"
"encoding/binary"
"encoding/json"
@ -79,7 +80,7 @@ func (d *StateStore) Features() []state.Feature {
}
// Get retrieves a dynamoDB item.
func (d *StateStore) Get(req *state.GetRequest) (*state.GetResponse, error) {
func (d *StateStore) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) {
input := &dynamodb.GetItemInput{
ConsistentRead: aws.Bool(req.Options.Consistency == state.Strong),
TableName: aws.String(d.table),
@ -90,7 +91,7 @@ func (d *StateStore) Get(req *state.GetRequest) (*state.GetResponse, error) {
},
}
result, err := d.client.GetItem(input)
result, err := d.client.GetItemWithContext(ctx, input)
if err != nil {
return nil, err
}
@ -133,13 +134,13 @@ func (d *StateStore) Get(req *state.GetRequest) (*state.GetResponse, error) {
}
// BulkGet performs a bulk get operations.
func (d *StateStore) BulkGet(req []state.GetRequest) (bool, []state.BulkGetResponse, error) {
func (d *StateStore) BulkGet(ctx context.Context, req []state.GetRequest) (bool, []state.BulkGetResponse, error) {
// TODO: replace with dynamodb.BatchGetItem for performance
return false, nil, nil
}
// Set saves a dynamoDB item.
func (d *StateStore) Set(req *state.SetRequest) error {
func (d *StateStore) Set(ctx context.Context, req *state.SetRequest) error {
item, err := d.getItemFromReq(req)
if err != nil {
return err
@ -165,7 +166,7 @@ func (d *StateStore) Set(req *state.SetRequest) error {
input.ConditionExpression = &condExpr
}
_, err = d.client.PutItem(input)
_, err = d.client.PutItemWithContext(ctx, input)
if err != nil && haveEtag {
switch cErr := err.(type) {
case *dynamodb.ConditionalCheckFailedException:
@ -177,11 +178,11 @@ func (d *StateStore) Set(req *state.SetRequest) error {
}
// BulkSet performs a bulk set operation.
func (d *StateStore) BulkSet(req []state.SetRequest) error {
func (d *StateStore) BulkSet(ctx context.Context, req []state.SetRequest) error {
writeRequests := []*dynamodb.WriteRequest{}
if len(req) == 1 {
return d.Set(&req[0])
return d.Set(ctx, &req[0])
}
for _, r := range req {
@ -210,7 +211,7 @@ func (d *StateStore) BulkSet(req []state.SetRequest) error {
requestItems := map[string][]*dynamodb.WriteRequest{}
requestItems[d.table] = writeRequests
_, e := d.client.BatchWriteItem(&dynamodb.BatchWriteItemInput{
_, e := d.client.BatchWriteItemWithContext(ctx, &dynamodb.BatchWriteItemInput{
RequestItems: requestItems,
})
@ -218,7 +219,7 @@ func (d *StateStore) BulkSet(req []state.SetRequest) error {
}
// Delete performs a delete operation.
func (d *StateStore) Delete(req *state.DeleteRequest) error {
func (d *StateStore) Delete(ctx context.Context, req *state.DeleteRequest) error {
input := &dynamodb.DeleteItemInput{
Key: map[string]*dynamodb.AttributeValue{
"key": {
@ -238,7 +239,7 @@ func (d *StateStore) Delete(req *state.DeleteRequest) error {
input.ExpressionAttributeValues = exprAttrValues
}
_, err := d.client.DeleteItem(input)
_, err := d.client.DeleteItemWithContext(ctx, input)
if err != nil {
switch cErr := err.(type) {
case *dynamodb.ConditionalCheckFailedException:
@ -250,11 +251,11 @@ func (d *StateStore) Delete(req *state.DeleteRequest) error {
}
// BulkDelete performs a bulk delete operation.
func (d *StateStore) BulkDelete(req []state.DeleteRequest) error {
func (d *StateStore) BulkDelete(ctx context.Context, req []state.DeleteRequest) error {
writeRequests := []*dynamodb.WriteRequest{}
if len(req) == 1 {
return d.Delete(&req[0])
return d.Delete(ctx, &req[0])
}
for _, r := range req {
@ -277,7 +278,7 @@ func (d *StateStore) BulkDelete(req []state.DeleteRequest) error {
requestItems := map[string][]*dynamodb.WriteRequest{}
requestItems[d.table] = writeRequests
_, e := d.client.BatchWriteItem(&dynamodb.BatchWriteItemInput{
_, e := d.client.BatchWriteItemWithContext(ctx, &dynamodb.BatchWriteItemInput{
RequestItems: requestItems,
})

View File

@ -15,12 +15,14 @@ limitations under the License.
package dynamodb
import (
"context"
"fmt"
"strconv"
"testing"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/service/dynamodb"
"github.com/aws/aws-sdk-go/service/dynamodb/dynamodbattribute"
"github.com/aws/aws-sdk-go/service/dynamodb/dynamodbiface"
@ -31,10 +33,10 @@ import (
)
type mockedDynamoDB struct {
GetItemFn func(input *dynamodb.GetItemInput) (*dynamodb.GetItemOutput, error)
PutItemFn func(input *dynamodb.PutItemInput) (*dynamodb.PutItemOutput, error)
DeleteItemFn func(input *dynamodb.DeleteItemInput) (*dynamodb.DeleteItemOutput, error)
BatchWriteItemFn func(input *dynamodb.BatchWriteItemInput) (*dynamodb.BatchWriteItemOutput, error)
GetItemWithContextFn func(ctx context.Context, input *dynamodb.GetItemInput, op ...request.Option) (*dynamodb.GetItemOutput, error)
PutItemWithContextFn func(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (*dynamodb.PutItemOutput, error)
DeleteItemWithContextFn func(ctx context.Context, input *dynamodb.DeleteItemInput, op ...request.Option) (*dynamodb.DeleteItemOutput, error)
BatchWriteItemWithContextFn func(ctx context.Context, input *dynamodb.BatchWriteItemInput, op ...request.Option) (*dynamodb.BatchWriteItemOutput, error)
dynamodbiface.DynamoDBAPI
}
@ -44,20 +46,20 @@ type DynamoDBItem struct {
TestAttributeName int64 `json:"testAttributeName"`
}
func (m *mockedDynamoDB) GetItem(input *dynamodb.GetItemInput) (*dynamodb.GetItemOutput, error) {
return m.GetItemFn(input)
func (m *mockedDynamoDB) GetItemWithContext(ctx context.Context, input *dynamodb.GetItemInput, op ...request.Option) (*dynamodb.GetItemOutput, error) {
return m.GetItemWithContextFn(ctx, input, op...)
}
func (m *mockedDynamoDB) PutItem(input *dynamodb.PutItemInput) (*dynamodb.PutItemOutput, error) {
return m.PutItemFn(input)
func (m *mockedDynamoDB) PutItemWithContext(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (*dynamodb.PutItemOutput, error) {
return m.PutItemWithContextFn(ctx, input, op...)
}
func (m *mockedDynamoDB) DeleteItem(input *dynamodb.DeleteItemInput) (*dynamodb.DeleteItemOutput, error) {
return m.DeleteItemFn(input)
func (m *mockedDynamoDB) DeleteItemWithContext(ctx context.Context, input *dynamodb.DeleteItemInput, op ...request.Option) (*dynamodb.DeleteItemOutput, error) {
return m.DeleteItemWithContextFn(ctx, input, op...)
}
func (m *mockedDynamoDB) BatchWriteItem(input *dynamodb.BatchWriteItemInput) (*dynamodb.BatchWriteItemOutput, error) {
return m.BatchWriteItemFn(input)
func (m *mockedDynamoDB) BatchWriteItemWithContext(ctx context.Context, input *dynamodb.BatchWriteItemInput, op ...request.Option) (*dynamodb.BatchWriteItemOutput, error) {
return m.BatchWriteItemWithContextFn(ctx, input, op...)
}
func TestInit(t *testing.T) {
@ -99,7 +101,7 @@ func TestGet(t *testing.T) {
t.Run("Successfully retrieve item", func(t *testing.T) {
ss := StateStore{
client: &mockedDynamoDB{
GetItemFn: func(input *dynamodb.GetItemInput) (output *dynamodb.GetItemOutput, err error) {
GetItemWithContextFn: func(ctx context.Context, input *dynamodb.GetItemInput, op ...request.Option) (output *dynamodb.GetItemOutput, err error) {
return &dynamodb.GetItemOutput{
Item: map[string]*dynamodb.AttributeValue{
"key": {
@ -123,7 +125,7 @@ func TestGet(t *testing.T) {
Consistency: "strong",
},
}
out, err := ss.Get(req)
out, err := ss.Get(context.TODO(), req)
assert.Nil(t, err)
assert.Equal(t, []byte("some value"), out.Data)
assert.Equal(t, "1bdead4badc0ffee", *out.ETag)
@ -131,7 +133,7 @@ func TestGet(t *testing.T) {
t.Run("Successfully retrieve item (with unexpired ttl)", func(t *testing.T) {
ss := StateStore{
client: &mockedDynamoDB{
GetItemFn: func(input *dynamodb.GetItemInput) (output *dynamodb.GetItemOutput, err error) {
GetItemWithContextFn: func(ctx context.Context, input *dynamodb.GetItemInput, op ...request.Option) (output *dynamodb.GetItemOutput, err error) {
return &dynamodb.GetItemOutput{
Item: map[string]*dynamodb.AttributeValue{
"key": {
@ -159,7 +161,7 @@ func TestGet(t *testing.T) {
Consistency: "strong",
},
}
out, err := ss.Get(req)
out, err := ss.Get(context.TODO(), req)
assert.Nil(t, err)
assert.Equal(t, []byte("some value"), out.Data)
assert.Equal(t, "1bdead4badc0ffee", *out.ETag)
@ -167,7 +169,7 @@ func TestGet(t *testing.T) {
t.Run("Successfully retrieve item (with expired ttl)", func(t *testing.T) {
ss := StateStore{
client: &mockedDynamoDB{
GetItemFn: func(input *dynamodb.GetItemInput) (output *dynamodb.GetItemOutput, err error) {
GetItemWithContextFn: func(ctx context.Context, input *dynamodb.GetItemInput, op ...request.Option) (output *dynamodb.GetItemOutput, err error) {
return &dynamodb.GetItemOutput{
Item: map[string]*dynamodb.AttributeValue{
"key": {
@ -195,7 +197,7 @@ func TestGet(t *testing.T) {
Consistency: "strong",
},
}
out, err := ss.Get(req)
out, err := ss.Get(context.TODO(), req)
assert.Nil(t, err)
assert.Nil(t, out.Data)
assert.Nil(t, out.ETag)
@ -203,7 +205,7 @@ func TestGet(t *testing.T) {
t.Run("Unsuccessfully get item", func(t *testing.T) {
ss := StateStore{
client: &mockedDynamoDB{
GetItemFn: func(input *dynamodb.GetItemInput) (output *dynamodb.GetItemOutput, err error) {
GetItemWithContextFn: func(ctx context.Context, input *dynamodb.GetItemInput, op ...request.Option) (output *dynamodb.GetItemOutput, err error) {
return nil, fmt.Errorf("failed to retrieve data")
},
},
@ -215,14 +217,14 @@ func TestGet(t *testing.T) {
Consistency: "strong",
},
}
out, err := ss.Get(req)
out, err := ss.Get(context.TODO(), req)
assert.NotNil(t, err)
assert.Nil(t, out)
})
t.Run("Unsuccessfully with empty response", func(t *testing.T) {
ss := StateStore{
client: &mockedDynamoDB{
GetItemFn: func(input *dynamodb.GetItemInput) (output *dynamodb.GetItemOutput, err error) {
GetItemWithContextFn: func(ctx context.Context, input *dynamodb.GetItemInput, op ...request.Option) (output *dynamodb.GetItemOutput, err error) {
return &dynamodb.GetItemOutput{
Item: map[string]*dynamodb.AttributeValue{},
}, nil
@ -236,7 +238,7 @@ func TestGet(t *testing.T) {
Consistency: "strong",
},
}
out, err := ss.Get(req)
out, err := ss.Get(context.TODO(), req)
assert.Nil(t, err)
assert.Nil(t, out.Data)
assert.Nil(t, out.ETag)
@ -244,7 +246,7 @@ func TestGet(t *testing.T) {
t.Run("Unsuccessfully with no required key", func(t *testing.T) {
ss := StateStore{
client: &mockedDynamoDB{
GetItemFn: func(input *dynamodb.GetItemInput) (output *dynamodb.GetItemOutput, err error) {
GetItemWithContextFn: func(ctx context.Context, input *dynamodb.GetItemInput, op ...request.Option) (output *dynamodb.GetItemOutput, err error) {
return &dynamodb.GetItemOutput{
Item: map[string]*dynamodb.AttributeValue{
"value2": {
@ -262,7 +264,7 @@ func TestGet(t *testing.T) {
Consistency: "strong",
},
}
out, err := ss.Get(req)
out, err := ss.Get(context.TODO(), req)
assert.Nil(t, err)
assert.Empty(t, out.Data)
})
@ -276,7 +278,7 @@ func TestSet(t *testing.T) {
t.Run("Successfully set item", func(t *testing.T) {
ss := StateStore{
client: &mockedDynamoDB{
PutItemFn: func(input *dynamodb.PutItemInput) (output *dynamodb.PutItemOutput, err error) {
PutItemWithContextFn: func(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (output *dynamodb.PutItemOutput, err error) {
assert.Equal(t, dynamodb.AttributeValue{
S: aws.String("key"),
}, *input.Item["key"])
@ -301,14 +303,14 @@ func TestSet(t *testing.T) {
Value: "value",
},
}
err := ss.Set(req)
err := ss.Set(context.TODO(), req)
assert.Nil(t, err)
})
t.Run("Successfully set item with matching etag", func(t *testing.T) {
ss := StateStore{
client: &mockedDynamoDB{
PutItemFn: func(input *dynamodb.PutItemInput) (output *dynamodb.PutItemOutput, err error) {
PutItemWithContextFn: func(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (output *dynamodb.PutItemOutput, err error) {
assert.Equal(t, dynamodb.AttributeValue{
S: aws.String("key"),
}, *input.Item["key"])
@ -339,14 +341,14 @@ func TestSet(t *testing.T) {
Value: "value",
},
}
err := ss.Set(req)
err := ss.Set(context.TODO(), req)
assert.Nil(t, err)
})
t.Run("Unsuccessfully set item with mismatched etag", func(t *testing.T) {
ss := StateStore{
client: &mockedDynamoDB{
PutItemFn: func(input *dynamodb.PutItemInput) (output *dynamodb.PutItemOutput, err error) {
PutItemWithContextFn: func(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (output *dynamodb.PutItemOutput, err error) {
assert.Equal(t, dynamodb.AttributeValue{
S: aws.String("key"),
}, *input.Item["key"])
@ -373,7 +375,7 @@ func TestSet(t *testing.T) {
},
}
err := ss.Set(req)
err := ss.Set(context.TODO(), req)
assert.NotNil(t, err)
switch tagErr := err.(type) {
case *state.ETagError:
@ -386,7 +388,7 @@ func TestSet(t *testing.T) {
t.Run("Successfully set item with first-write-concurrency", func(t *testing.T) {
ss := StateStore{
client: &mockedDynamoDB{
PutItemFn: func(input *dynamodb.PutItemInput) (output *dynamodb.PutItemOutput, err error) {
PutItemWithContextFn: func(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (output *dynamodb.PutItemOutput, err error) {
assert.Equal(t, dynamodb.AttributeValue{
S: aws.String("key"),
}, *input.Item["key"])
@ -415,14 +417,14 @@ func TestSet(t *testing.T) {
Concurrency: state.FirstWrite,
},
}
err := ss.Set(req)
err := ss.Set(context.TODO(), req)
assert.Nil(t, err)
})
t.Run("Unsuccessfully set item with first-write-concurrency", func(t *testing.T) {
ss := StateStore{
client: &mockedDynamoDB{
PutItemFn: func(input *dynamodb.PutItemInput) (output *dynamodb.PutItemOutput, err error) {
PutItemWithContextFn: func(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (output *dynamodb.PutItemOutput, err error) {
assert.Equal(t, dynamodb.AttributeValue{
S: aws.String("key"),
}, *input.Item["key"])
@ -446,7 +448,7 @@ func TestSet(t *testing.T) {
Concurrency: state.FirstWrite,
},
}
err := ss.Set(req)
err := ss.Set(context.TODO(), req)
assert.NotNil(t, err)
switch err.(type) {
case *state.ETagError:
@ -458,7 +460,7 @@ func TestSet(t *testing.T) {
t.Run("Successfully set item with ttl = -1", func(t *testing.T) {
ss := StateStore{
client: &mockedDynamoDB{
PutItemFn: func(input *dynamodb.PutItemInput) (output *dynamodb.PutItemOutput, err error) {
PutItemWithContextFn: func(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (output *dynamodb.PutItemOutput, err error) {
assert.Equal(t, len(input.Item), 4)
result := DynamoDBItem{}
dynamodbattribute.UnmarshalMap(input.Item, &result)
@ -487,13 +489,13 @@ func TestSet(t *testing.T) {
"ttlInSeconds": "-1",
},
}
err := ss.Set(req)
err := ss.Set(context.TODO(), req)
assert.Nil(t, err)
})
t.Run("Successfully set item with 'correct' ttl", func(t *testing.T) {
ss := StateStore{
client: &mockedDynamoDB{
PutItemFn: func(input *dynamodb.PutItemInput) (output *dynamodb.PutItemOutput, err error) {
PutItemWithContextFn: func(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (output *dynamodb.PutItemOutput, err error) {
assert.Equal(t, len(input.Item), 4)
result := DynamoDBItem{}
dynamodbattribute.UnmarshalMap(input.Item, &result)
@ -522,14 +524,14 @@ func TestSet(t *testing.T) {
"ttlInSeconds": "180",
},
}
err := ss.Set(req)
err := ss.Set(context.TODO(), req)
assert.Nil(t, err)
})
t.Run("Unsuccessfully set item", func(t *testing.T) {
ss := StateStore{
client: &mockedDynamoDB{
PutItemFn: func(input *dynamodb.PutItemInput) (output *dynamodb.PutItemOutput, err error) {
PutItemWithContextFn: func(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (output *dynamodb.PutItemOutput, err error) {
return nil, fmt.Errorf("unable to put item")
},
},
@ -540,13 +542,13 @@ func TestSet(t *testing.T) {
Value: "value",
},
}
err := ss.Set(req)
err := ss.Set(context.TODO(), req)
assert.NotNil(t, err)
})
t.Run("Successfully set item with correct ttl but without component metadata", func(t *testing.T) {
ss := StateStore{
client: &mockedDynamoDB{
PutItemFn: func(input *dynamodb.PutItemInput) (output *dynamodb.PutItemOutput, err error) {
PutItemWithContextFn: func(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (output *dynamodb.PutItemOutput, err error) {
assert.Equal(t, dynamodb.AttributeValue{
S: aws.String("someKey"),
}, *input.Item["key"])
@ -574,13 +576,13 @@ func TestSet(t *testing.T) {
"ttlInSeconds": "180",
},
}
err := ss.Set(req)
err := ss.Set(context.TODO(), req)
assert.Nil(t, err)
})
t.Run("Unsuccessfully set item with ttl (invalid value)", func(t *testing.T) {
ss := StateStore{
client: &mockedDynamoDB{
PutItemFn: func(input *dynamodb.PutItemInput) (output *dynamodb.PutItemOutput, err error) {
PutItemWithContextFn: func(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (output *dynamodb.PutItemOutput, err error) {
assert.Equal(t, map[string]*dynamodb.AttributeValue{
"key": {
S: aws.String("somekey"),
@ -613,7 +615,7 @@ func TestSet(t *testing.T) {
"ttlInSeconds": "invalidvalue",
},
}
err := ss.Set(req)
err := ss.Set(context.TODO(), req)
assert.NotNil(t, err)
assert.Equal(t, "dynamodb error: failed to parse ttlInSeconds: strconv.ParseInt: parsing \"invalidvalue\": invalid syntax", err.Error())
})
@ -628,7 +630,7 @@ func TestBulkSet(t *testing.T) {
tableName := "table_name"
ss := StateStore{
client: &mockedDynamoDB{
BatchWriteItemFn: func(input *dynamodb.BatchWriteItemInput) (output *dynamodb.BatchWriteItemOutput, err error) {
BatchWriteItemWithContextFn: func(ctx context.Context, input *dynamodb.BatchWriteItemInput, op ...request.Option) (output *dynamodb.BatchWriteItemOutput, err error) {
expected := map[string][]*dynamodb.WriteRequest{}
expected[tableName] = []*dynamodb.WriteRequest{
{
@ -688,14 +690,14 @@ func TestBulkSet(t *testing.T) {
},
},
}
err := ss.BulkSet(req)
err := ss.BulkSet(context.TODO(), req)
assert.Nil(t, err)
})
t.Run("Successfully set items with ttl = -1", func(t *testing.T) {
tableName := "table_name"
ss := StateStore{
client: &mockedDynamoDB{
BatchWriteItemFn: func(input *dynamodb.BatchWriteItemInput) (output *dynamodb.BatchWriteItemOutput, err error) {
BatchWriteItemWithContextFn: func(ctx context.Context, input *dynamodb.BatchWriteItemInput, op ...request.Option) (output *dynamodb.BatchWriteItemOutput, err error) {
expected := map[string][]*dynamodb.WriteRequest{}
expected[tableName] = []*dynamodb.WriteRequest{
{
@ -761,14 +763,14 @@ func TestBulkSet(t *testing.T) {
},
},
}
err := ss.BulkSet(req)
err := ss.BulkSet(context.TODO(), req)
assert.Nil(t, err)
})
t.Run("Successfully set items with ttl", func(t *testing.T) {
tableName := "table_name"
ss := StateStore{
client: &mockedDynamoDB{
BatchWriteItemFn: func(input *dynamodb.BatchWriteItemInput) (output *dynamodb.BatchWriteItemOutput, err error) {
BatchWriteItemWithContextFn: func(ctx context.Context, input *dynamodb.BatchWriteItemInput, op ...request.Option) (output *dynamodb.BatchWriteItemOutput, err error) {
expected := map[string][]*dynamodb.WriteRequest{}
// This might fail occasionally due to timestamp precision.
timestamp := time.Now().Unix() + 90
@ -836,13 +838,13 @@ func TestBulkSet(t *testing.T) {
},
},
}
err := ss.BulkSet(req)
err := ss.BulkSet(context.TODO(), req)
assert.Nil(t, err)
})
t.Run("Unsuccessfully set items", func(t *testing.T) {
ss := StateStore{
client: &mockedDynamoDB{
BatchWriteItemFn: func(input *dynamodb.BatchWriteItemInput) (output *dynamodb.BatchWriteItemOutput, err error) {
BatchWriteItemWithContextFn: func(ctx context.Context, input *dynamodb.BatchWriteItemInput, op ...request.Option) (output *dynamodb.BatchWriteItemOutput, err error) {
return nil, fmt.Errorf("unable to bulk write items")
},
},
@ -861,7 +863,7 @@ func TestBulkSet(t *testing.T) {
},
},
}
err := ss.BulkSet(req)
err := ss.BulkSet(context.TODO(), req)
assert.NotNil(t, err)
})
}
@ -874,7 +876,7 @@ func TestDelete(t *testing.T) {
ss := StateStore{
client: &mockedDynamoDB{
DeleteItemFn: func(input *dynamodb.DeleteItemInput) (output *dynamodb.DeleteItemOutput, err error) {
DeleteItemWithContextFn: func(ctx context.Context, input *dynamodb.DeleteItemInput, op ...request.Option) (output *dynamodb.DeleteItemOutput, err error) {
assert.Equal(t, map[string]*dynamodb.AttributeValue{
"key": {
S: aws.String(req.Key),
@ -885,7 +887,7 @@ func TestDelete(t *testing.T) {
},
},
}
err := ss.Delete(req)
err := ss.Delete(context.TODO(), req)
assert.Nil(t, err)
})
@ -898,7 +900,7 @@ func TestDelete(t *testing.T) {
ss := StateStore{
client: &mockedDynamoDB{
DeleteItemFn: func(input *dynamodb.DeleteItemInput) (output *dynamodb.DeleteItemOutput, err error) {
DeleteItemWithContextFn: func(ctx context.Context, input *dynamodb.DeleteItemInput, op ...request.Option) (output *dynamodb.DeleteItemOutput, err error) {
assert.Equal(t, map[string]*dynamodb.AttributeValue{
"key": {
S: aws.String(req.Key),
@ -913,7 +915,7 @@ func TestDelete(t *testing.T) {
},
},
}
err := ss.Delete(req)
err := ss.Delete(context.TODO(), req)
assert.Nil(t, err)
})
@ -926,7 +928,7 @@ func TestDelete(t *testing.T) {
ss := StateStore{
client: &mockedDynamoDB{
DeleteItemFn: func(input *dynamodb.DeleteItemInput) (output *dynamodb.DeleteItemOutput, err error) {
DeleteItemWithContextFn: func(ctx context.Context, input *dynamodb.DeleteItemInput, op ...request.Option) (output *dynamodb.DeleteItemOutput, err error) {
assert.Equal(t, map[string]*dynamodb.AttributeValue{
"key": {
S: aws.String(req.Key),
@ -942,7 +944,7 @@ func TestDelete(t *testing.T) {
},
},
}
err := ss.Delete(req)
err := ss.Delete(context.TODO(), req)
assert.NotNil(t, err)
switch tagErr := err.(type) {
case *state.ETagError:
@ -955,7 +957,7 @@ func TestDelete(t *testing.T) {
t.Run("Unsuccessfully delete item", func(t *testing.T) {
ss := StateStore{
client: &mockedDynamoDB{
DeleteItemFn: func(input *dynamodb.DeleteItemInput) (output *dynamodb.DeleteItemOutput, err error) {
DeleteItemWithContextFn: func(ctx context.Context, input *dynamodb.DeleteItemInput, op ...request.Option) (output *dynamodb.DeleteItemOutput, err error) {
return nil, fmt.Errorf("unable to delete item")
},
},
@ -963,7 +965,7 @@ func TestDelete(t *testing.T) {
req := &state.DeleteRequest{
Key: "key",
}
err := ss.Delete(req)
err := ss.Delete(context.TODO(), req)
assert.NotNil(t, err)
})
}
@ -973,7 +975,7 @@ func TestBulkDelete(t *testing.T) {
tableName := "table_name"
ss := StateStore{
client: &mockedDynamoDB{
BatchWriteItemFn: func(input *dynamodb.BatchWriteItemInput) (output *dynamodb.BatchWriteItemOutput, err error) {
BatchWriteItemWithContextFn: func(ctx context.Context, input *dynamodb.BatchWriteItemInput, op ...request.Option) (output *dynamodb.BatchWriteItemOutput, err error) {
expected := map[string][]*dynamodb.WriteRequest{}
expected[tableName] = []*dynamodb.WriteRequest{
{
@ -1012,13 +1014,13 @@ func TestBulkDelete(t *testing.T) {
Key: "key2",
},
}
err := ss.BulkDelete(req)
err := ss.BulkDelete(context.TODO(), req)
assert.Nil(t, err)
})
t.Run("Unsuccessfully delete items", func(t *testing.T) {
ss := StateStore{
client: &mockedDynamoDB{
BatchWriteItemFn: func(input *dynamodb.BatchWriteItemInput) (output *dynamodb.BatchWriteItemOutput, err error) {
BatchWriteItemWithContextFn: func(ctx context.Context, input *dynamodb.BatchWriteItemInput, op ...request.Option) (output *dynamodb.BatchWriteItemOutput, err error) {
return nil, fmt.Errorf("unable to bulk write items")
},
},
@ -1031,7 +1033,7 @@ func TestBulkDelete(t *testing.T) {
Key: "key",
},
}
err := ss.BulkDelete(req)
err := ss.BulkDelete(context.TODO(), req)
assert.NotNil(t, err)
})
}

View File

@ -130,21 +130,21 @@ func (r *StateStore) Features() []state.Feature {
}
// Delete the state.
func (r *StateStore) Delete(req *state.DeleteRequest) error {
func (r *StateStore) Delete(ctx context.Context, req *state.DeleteRequest) error {
r.logger.Debugf("delete %s", req.Key)
return r.deleteFile(context.Background(), req)
return r.deleteFile(ctx, req)
}
// Get the state.
func (r *StateStore) Get(req *state.GetRequest) (*state.GetResponse, error) {
func (r *StateStore) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) {
r.logger.Debugf("get %s", req.Key)
return r.readFile(context.Background(), req)
return r.readFile(ctx, req)
}
// Set the state.
func (r *StateStore) Set(req *state.SetRequest) error {
func (r *StateStore) Set(ctx context.Context, req *state.SetRequest) error {
r.logger.Debugf("saving %s", req.Key)
return r.writeFile(context.Background(), req)
return r.writeFile(ctx, req)
}
func (r *StateStore) Ping() error {

View File

@ -202,7 +202,7 @@ func (c *StateStore) Features() []state.Feature {
}
// Get retrieves a CosmosDB item.
func (c *StateStore) Get(req *state.GetRequest) (*state.GetResponse, error) {
func (c *StateStore) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) {
partitionKey := populatePartitionMetadata(req.Key, req.Metadata)
options := azcosmos.ItemOptions{}
@ -212,9 +212,7 @@ func (c *StateStore) Get(req *state.GetRequest) (*state.GetResponse, error) {
options.ConsistencyLevel = azcosmos.ConsistencyLevelEventual.ToPtr()
}
ctx, cancel := context.WithTimeout(context.Background(), defaultTimeout)
readItem, err := c.client.ReadItem(ctx, azcosmos.NewPartitionKeyString(partitionKey), req.Key, &options)
cancel()
if err != nil {
var responseErr *azcore.ResponseError
if errors.As(err, &responseErr) && responseErr.ErrorCode == "NotFound" {
@ -263,7 +261,7 @@ func (c *StateStore) Get(req *state.GetRequest) (*state.GetResponse, error) {
}
// Set saves a CosmosDB item.
func (c *StateStore) Set(req *state.SetRequest) error {
func (c *StateStore) Set(ctx context.Context, req *state.SetRequest) error {
err := state.CheckRequestOptions(req.Options)
if err != nil {
return err
@ -301,10 +299,8 @@ func (c *StateStore) Set(req *state.SetRequest) error {
return err
}
ctx, cancel := context.WithTimeout(context.Background(), defaultTimeout)
pk := azcosmos.NewPartitionKeyString(partitionKey)
_, err = c.client.UpsertItem(ctx, pk, marsh, &options)
cancel()
if err != nil {
return err
}
@ -312,7 +308,7 @@ func (c *StateStore) Set(req *state.SetRequest) error {
}
// Delete performs a delete operation.
func (c *StateStore) Delete(req *state.DeleteRequest) error {
func (c *StateStore) Delete(ctx context.Context, req *state.DeleteRequest) error {
err := state.CheckRequestOptions(req.Options)
if err != nil {
return err
@ -330,10 +326,8 @@ func (c *StateStore) Delete(req *state.DeleteRequest) error {
options.ConsistencyLevel = azcosmos.ConsistencyLevelEventual.ToPtr()
}
ctx, cancel := context.WithTimeout(context.Background(), defaultTimeout)
pk := azcosmos.NewPartitionKeyString(partitionKey)
_, err = c.client.DeleteItem(ctx, pk, req.Key, &options)
cancel()
if err != nil && !isNotFoundError(err) {
c.logger.Debugf("Error from cosmos.DeleteDocument e=%e, e.Error=%s", err, err.Error())
if req.ETag != nil && *req.ETag != "" {
@ -346,7 +340,7 @@ func (c *StateStore) Delete(req *state.DeleteRequest) error {
}
// Multi performs a transactional operation. succeeds only if all operations succeed, and fails if one or more operations fail.
func (c *StateStore) Multi(request *state.TransactionalStateRequest) (err error) {
func (c *StateStore) Multi(ctx context.Context, request *state.TransactionalStateRequest) (err error) {
if len(request.Operations) == 0 {
c.logger.Debugf("No Operations Provided")
return nil
@ -413,9 +407,7 @@ func (c *StateStore) Multi(request *state.TransactionalStateRequest) (err error)
c.logger.Debugf("#operations=%d,partitionkey=%s", numOperations, partitionKey)
ctx, cancel := context.WithTimeout(context.Background(), defaultTimeout)
batchResponse, err := c.client.ExecuteTransactionalBatch(ctx, batch, nil)
cancel()
if err != nil {
return err
}
@ -440,7 +432,7 @@ func (c *StateStore) Multi(request *state.TransactionalStateRequest) (err error)
return nil
}
func (c *StateStore) Query(req *state.QueryRequest) (*state.QueryResponse, error) {
func (c *StateStore) Query(ctx context.Context, req *state.QueryRequest) (*state.QueryResponse, error) {
q := &Query{}
qbuilder := query.NewQueryBuilder(q)
@ -448,7 +440,7 @@ func (c *StateStore) Query(req *state.QueryRequest) (*state.QueryResponse, error
return &state.QueryResponse{}, err
}
data, token, err := q.execute(c.client)
data, token, err := q.execute(ctx, c.client)
if err != nil {
return nil, err
}

View File

@ -144,7 +144,7 @@ func (q *Query) setNextParameter(val string) string {
return pname
}
func (q *Query) execute(client *azcosmos.ContainerClient) ([]state.QueryItem, string, error) {
func (q *Query) execute(ctx context.Context, client *azcosmos.ContainerClient) ([]state.QueryItem, string, error) {
opts := &azcosmos.QueryOptions{}
resultLimit := q.limit
@ -160,9 +160,7 @@ func (q *Query) execute(client *azcosmos.ContainerClient) ([]state.QueryItem, st
token := ""
for queryPager.More() {
ctx, cancel := context.WithTimeout(context.Background(), defaultTimeout)
queryResponse, innerErr := queryPager.NextPage(ctx)
cancel()
if innerErr != nil {
return nil, "", innerErr
}

View File

@ -163,10 +163,10 @@ func (r *StateStore) Features() []state.Feature {
return r.features
}
func (r *StateStore) Delete(req *state.DeleteRequest) error {
func (r *StateStore) Delete(ctx context.Context, req *state.DeleteRequest) error {
r.logger.Debugf("delete %s", req.Key)
err := r.deleteRow(req)
err := r.deleteRow(ctx, req)
if err != nil {
if req.ETag != nil {
return state.NewETagError(state.ETagMismatch, err)
@ -179,12 +179,10 @@ func (r *StateStore) Delete(req *state.DeleteRequest) error {
return err
}
func (r *StateStore) Get(req *state.GetRequest) (*state.GetResponse, error) {
func (r *StateStore) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) {
r.logger.Debugf("fetching %s", req.Key)
pk, rk := getPartitionAndRowKey(req.Key, r.cosmosDBMode)
getContext, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
resp, err := r.client.GetEntity(getContext, pk, rk, nil)
resp, err := r.client.GetEntity(ctx, pk, rk, nil)
if err != nil {
if isNotFoundError(err) {
return &state.GetResponse{}, nil
@ -200,10 +198,10 @@ func (r *StateStore) Get(req *state.GetRequest) (*state.GetResponse, error) {
}, err
}
func (r *StateStore) Set(req *state.SetRequest) error {
func (r *StateStore) Set(ctx context.Context, req *state.SetRequest) error {
r.logger.Debugf("saving %s", req.Key)
err := r.writeRow(req)
err := r.writeRow(ctx, req)
return err
}
@ -254,30 +252,26 @@ func getTablesMetadata(metadata map[string]string) (*tablesMetadata, error) {
return &meta, nil
}
func (r *StateStore) writeRow(req *state.SetRequest) error {
func (r *StateStore) writeRow(ctx context.Context, req *state.SetRequest) error {
marshalledEntity, err := r.marshal(req)
if err != nil {
return err
}
writeContext, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
// InsertOrReplace does not support ETag concurrency, therefore we will use Insert to check for key existence
// and then use Update to update the key if it exists with the specified ETag
_, err = r.client.AddEntity(writeContext, marshalledEntity, nil)
_, err = r.client.AddEntity(ctx, marshalledEntity, nil)
if err != nil {
// If Insert failed because item already exists, try to Update instead per Upsert semantics
if isEntityAlreadyExistsError(err) {
updateContext, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
// Always Update using the etag when provided even if Concurrency != FirstWrite.
// Today the presence of etag takes precedence over Concurrency.
// In the future #2739 will impose a breaking change which must disallow the use of etag when not using FirstWrite.
if req.ETag != nil && *req.ETag != "" {
etag := azcore.ETag(*req.ETag)
_, uerr := r.client.UpdateEntity(updateContext, marshalledEntity, &aztables.UpdateEntityOptions{
_, uerr := r.client.UpdateEntity(ctx, marshalledEntity, &aztables.UpdateEntityOptions{
IfMatch: &etag,
UpdateMode: aztables.UpdateModeReplace,
})
@ -295,7 +289,7 @@ func (r *StateStore) writeRow(req *state.SetRequest) error {
return state.NewETagError(state.ETagMismatch, errors.New("update with Concurrency.FirstWrite without ETag"))
} else {
// Finally, last write semantics without ETag should always perform a force update.
_, uerr := r.client.UpdateEntity(updateContext, marshalledEntity, &aztables.UpdateEntityOptions{
_, uerr := r.client.UpdateEntity(ctx, marshalledEntity, &aztables.UpdateEntityOptions{
IfMatch: nil, // this is the same as "*" matching all ETags
UpdateMode: aztables.UpdateModeReplace,
})
@ -336,19 +330,16 @@ func isTableAlreadyExistsError(err error) bool {
return false
}
func (r *StateStore) deleteRow(req *state.DeleteRequest) error {
func (r *StateStore) deleteRow(ctx context.Context, req *state.DeleteRequest) error {
pk, rk := getPartitionAndRowKey(req.Key, r.cosmosDBMode)
deleteContext, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
if req.ETag != nil {
azcoreETag := azcore.ETag(*req.ETag)
_, err := r.client.DeleteEntity(deleteContext, pk, rk, &aztables.DeleteEntityOptions{IfMatch: &azcoreETag})
_, err := r.client.DeleteEntity(ctx, pk, rk, &aztables.DeleteEntityOptions{IfMatch: &azcoreETag})
return err
}
all := azcore.ETagAny
_, err := r.client.DeleteEntity(deleteContext, pk, rk, &aztables.DeleteEntityOptions{IfMatch: &all})
_, err := r.client.DeleteEntity(ctx, pk, rk, &aztables.DeleteEntityOptions{IfMatch: &all})
return err
}

View File

@ -14,6 +14,7 @@ limitations under the License.
package cassandra
import (
"context"
"errors"
"fmt"
"strconv"
@ -230,12 +231,12 @@ func getCassandraMetadata(metadata state.Metadata) (*cassandraMetadata, error) {
}
// Delete performs a delete operation.
func (c *Cassandra) Delete(req *state.DeleteRequest) error {
return c.session.Query(fmt.Sprintf("DELETE FROM %s WHERE key = ?", c.table), req.Key).Exec()
func (c *Cassandra) Delete(ctx context.Context, req *state.DeleteRequest) error {
return c.session.Query(fmt.Sprintf("DELETE FROM %s WHERE key = ?", c.table), req.Key).WithContext(ctx).Exec()
}
// Get retrieves state from cassandra with a key.
func (c *Cassandra) Get(req *state.GetRequest) (*state.GetResponse, error) {
func (c *Cassandra) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) {
session := c.session
if req.Options.Consistency == state.Strong {
@ -254,7 +255,7 @@ func (c *Cassandra) Get(req *state.GetRequest) (*state.GetResponse, error) {
session = sess
}
results, err := session.Query(fmt.Sprintf("SELECT value FROM %s WHERE key = ?", c.table), req.Key).Iter().SliceMap()
results, err := session.Query(fmt.Sprintf("SELECT value FROM %s WHERE key = ?", c.table), req.Key).WithContext(ctx).Iter().SliceMap()
if err != nil {
return nil, err
}
@ -269,7 +270,7 @@ func (c *Cassandra) Get(req *state.GetRequest) (*state.GetResponse, error) {
}
// Set saves state into cassandra.
func (c *Cassandra) Set(req *state.SetRequest) error {
func (c *Cassandra) Set(ctx context.Context, req *state.SetRequest) error {
var bt []byte
b, ok := req.Value.([]byte)
if ok {
@ -302,10 +303,10 @@ func (c *Cassandra) Set(req *state.SetRequest) error {
}
if ttl != nil {
return session.Query(fmt.Sprintf("INSERT INTO %s (key, value) VALUES (?, ?) USING TTL ?", c.table), req.Key, bt, *ttl).Exec()
return session.Query(fmt.Sprintf("INSERT INTO %s (key, value) VALUES (?, ?) USING TTL ?", c.table), req.Key, bt, *ttl).WithContext(ctx).Exec()
}
return session.Query(fmt.Sprintf("INSERT INTO %s (key, value) VALUES (?, ?)", c.table), req.Key, bt).Exec()
return session.Query(fmt.Sprintf("INSERT INTO %s (key, value) VALUES (?, ?)", c.table), req.Key, bt).WithContext(ctx).Exec()
}
func (c *Cassandra) createSession(consistency gocql.Consistency) (*gocql.Session, error) {

View File

@ -14,6 +14,8 @@ limitations under the License.
package cockroachdb
import (
"context"
"github.com/dapr/components-contrib/state"
"github.com/dapr/kit/logger"
)
@ -53,18 +55,18 @@ func (c *CockroachDB) Features() []state.Feature {
}
// Delete removes an entity from the store.
func (c *CockroachDB) Delete(req *state.DeleteRequest) error {
return c.dbaccess.Delete(req)
func (c *CockroachDB) Delete(ctx context.Context, req *state.DeleteRequest) error {
return c.dbaccess.Delete(ctx, req)
}
// Get returns an entity from store.
func (c *CockroachDB) Get(req *state.GetRequest) (*state.GetResponse, error) {
return c.dbaccess.Get(req)
func (c *CockroachDB) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) {
return c.dbaccess.Get(ctx, req)
}
// Set adds/updates an entity on store.
func (c *CockroachDB) Set(req *state.SetRequest) error {
return c.dbaccess.Set(req)
func (c *CockroachDB) Set(ctx context.Context, req *state.SetRequest) error {
return c.dbaccess.Set(ctx, req)
}
// Ping checks if database is available.
@ -73,29 +75,29 @@ func (c *CockroachDB) Ping() error {
}
// BulkDelete removes multiple entries from the store.
func (c *CockroachDB) BulkDelete(req []state.DeleteRequest) error {
return c.dbaccess.BulkDelete(req)
func (c *CockroachDB) BulkDelete(ctx context.Context, req []state.DeleteRequest) error {
return c.dbaccess.BulkDelete(ctx, req)
}
// BulkGet performs a bulks get operations.
func (c *CockroachDB) BulkGet(req []state.GetRequest) (bool, []state.BulkGetResponse, error) {
func (c *CockroachDB) BulkGet(ctx context.Context, req []state.GetRequest) (bool, []state.BulkGetResponse, error) {
// TODO: replace with ExecuteMulti for performance.
return false, nil, nil
}
// BulkSet adds/updates multiple entities on store.
func (c *CockroachDB) BulkSet(req []state.SetRequest) error {
return c.dbaccess.BulkSet(req)
func (c *CockroachDB) BulkSet(ctx context.Context, req []state.SetRequest) error {
return c.dbaccess.BulkSet(ctx, req)
}
// Multi handles multiple transactions. Implements TransactionalStore.
func (c *CockroachDB) Multi(request *state.TransactionalStateRequest) error {
return c.dbaccess.ExecuteMulti(request)
func (c *CockroachDB) Multi(ctx context.Context, request *state.TransactionalStateRequest) error {
return c.dbaccess.ExecuteMulti(ctx, request)
}
// Query executes a query against store.
func (c *CockroachDB) Query(req *state.QueryRequest) (*state.QueryResponse, error) {
return c.dbaccess.Query(req)
func (c *CockroachDB) Query(ctx context.Context, req *state.QueryRequest) (*state.QueryResponse, error) {
return c.dbaccess.Query(ctx, req)
}
// Close implements io.Closer.

View File

@ -14,6 +14,7 @@ limitations under the License.
package cockroachdb
import (
"context"
"database/sql"
"encoding/base64"
"encoding/json"
@ -96,12 +97,12 @@ func (p *cockroachDBAccess) Init(metadata state.Metadata) error {
}
// Set makes an insert or update to the database.
func (p *cockroachDBAccess) Set(req *state.SetRequest) error {
return state.SetWithOptions(p.setValue, req)
func (p *cockroachDBAccess) Set(ctx context.Context, req *state.SetRequest) error {
return state.SetWithOptions(ctx, p.setValue, req)
}
// setValue is an internal implementation of set to enable passing the logic to state.SetWithRetries as a func.
func (p *cockroachDBAccess) setValue(req *state.SetRequest) error {
func (p *cockroachDBAccess) setValue(ctx context.Context, req *state.SetRequest) error {
p.logger.Debug("Setting state value in CockroachDB")
value, isBinary, err := validateAndReturnValue(req)
@ -114,7 +115,7 @@ func (p *cockroachDBAccess) setValue(req *state.SetRequest) error {
// Sprintf is required for table name because sql.DB does not substitute parameters for table names.
// Other parameters use sql.DB parameter substitution.
if req.ETag == nil {
result, err = p.db.Exec(fmt.Sprintf(
result, err = p.db.ExecContext(ctx, fmt.Sprintf(
`INSERT INTO %s (key, value, isbinary, etag) VALUES ($1, $2, $3, 1)
ON CONFLICT (key) DO UPDATE SET value = $2, isbinary = $3, updatedate = NOW(), etag = EXCLUDED.etag + 1;`,
tableName), req.Key, value, isBinary)
@ -127,7 +128,7 @@ func (p *cockroachDBAccess) setValue(req *state.SetRequest) error {
etag := uint32(etag64)
// When an etag is provided do an update - no insert.
result, err = p.db.Exec(fmt.Sprintf(
result, err = p.db.ExecContext(ctx, fmt.Sprintf(
`UPDATE %s SET value = $1, isbinary = $2, updatedate = NOW(), etag = etag + 1
WHERE key = $3 AND etag = $4;`,
tableName), value, isBinary, req.Key, etag)
@ -149,7 +150,7 @@ func (p *cockroachDBAccess) setValue(req *state.SetRequest) error {
return nil
}
func (p *cockroachDBAccess) BulkSet(req []state.SetRequest) error {
func (p *cockroachDBAccess) BulkSet(ctx context.Context, req []state.SetRequest) error {
p.logger.Debug("Executing BulkSet request")
tx, err := p.db.Begin()
if err != nil {
@ -159,7 +160,7 @@ func (p *cockroachDBAccess) BulkSet(req []state.SetRequest) error {
if len(req) > 0 {
for _, s := range req {
sa := s // Fix for gosec G601: Implicit memory aliasing in for loop.
err = p.Set(&sa)
err = p.Set(ctx, &sa)
if err != nil {
tx.Rollback()
@ -174,7 +175,7 @@ func (p *cockroachDBAccess) BulkSet(req []state.SetRequest) error {
}
// Get returns data from the database. If data does not exist for the key an empty state.GetResponse will be returned.
func (p *cockroachDBAccess) Get(req *state.GetRequest) (*state.GetResponse, error) {
func (p *cockroachDBAccess) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) {
p.logger.Debug("Getting state value from CockroachDB")
if req.Key == "" {
return nil, fmt.Errorf("missing key in get operation")
@ -183,7 +184,7 @@ func (p *cockroachDBAccess) Get(req *state.GetRequest) (*state.GetResponse, erro
var value string
var isBinary bool
var etag int
err := p.db.QueryRow(fmt.Sprintf("SELECT value, isbinary, etag FROM %s WHERE key = $1", tableName), req.Key).Scan(&value, &isBinary, &etag)
err := p.db.QueryRowContext(ctx, fmt.Sprintf("SELECT value, isbinary, etag FROM %s WHERE key = $1", tableName), req.Key).Scan(&value, &isBinary, &etag)
if err != nil {
// If no rows exist, return an empty response, otherwise return the error.
if errors.Is(err, sql.ErrNoRows) {
@ -222,12 +223,12 @@ func (p *cockroachDBAccess) Get(req *state.GetRequest) (*state.GetResponse, erro
}
// Delete removes an item from the state store.
func (p *cockroachDBAccess) Delete(req *state.DeleteRequest) error {
return state.DeleteWithOptions(p.deleteValue, req)
func (p *cockroachDBAccess) Delete(ctx context.Context, req *state.DeleteRequest) error {
return state.DeleteWithOptions(ctx, p.deleteValue, req)
}
// deleteValue is an internal implementation of delete to enable passing the logic to state.DeleteWithRetries as a func.
func (p *cockroachDBAccess) deleteValue(req *state.DeleteRequest) error {
func (p *cockroachDBAccess) deleteValue(ctx context.Context, req *state.DeleteRequest) error {
p.logger.Debug("Deleting state value from CockroachDB")
if req.Key == "" {
return fmt.Errorf("missing key in delete operation")
@ -237,7 +238,7 @@ func (p *cockroachDBAccess) deleteValue(req *state.DeleteRequest) error {
var err error
if req.ETag == nil {
result, err = p.db.Exec("DELETE FROM state WHERE key = $1", req.Key)
result, err = p.db.ExecContext(ctx, "DELETE FROM state WHERE key = $1", req.Key)
} else {
var etag64 uint64
etag64, err = strconv.ParseUint(*req.ETag, 10, 32)
@ -246,7 +247,7 @@ func (p *cockroachDBAccess) deleteValue(req *state.DeleteRequest) error {
}
etag := uint32(etag64)
result, err = p.db.Exec("DELETE FROM state WHERE key = $1 and etag = $2", req.Key, etag)
result, err = p.db.ExecContext(ctx, "DELETE FROM state WHERE key = $1 and etag = $2", req.Key, etag)
}
if err != nil {
@ -265,7 +266,7 @@ func (p *cockroachDBAccess) deleteValue(req *state.DeleteRequest) error {
return nil
}
func (p *cockroachDBAccess) BulkDelete(req []state.DeleteRequest) error {
func (p *cockroachDBAccess) BulkDelete(ctx context.Context, req []state.DeleteRequest) error {
p.logger.Debug("Executing BulkDelete request")
tx, err := p.db.Begin()
if err != nil {
@ -275,7 +276,7 @@ func (p *cockroachDBAccess) BulkDelete(req []state.DeleteRequest) error {
if len(req) > 0 {
for _, d := range req {
da := d // Fix for gosec G601: Implicit memory aliasing in for loop.
err = p.Delete(&da)
err = p.Delete(ctx, &da)
if err != nil {
tx.Rollback()
@ -289,7 +290,7 @@ func (p *cockroachDBAccess) BulkDelete(req []state.DeleteRequest) error {
return err
}
func (p *cockroachDBAccess) ExecuteMulti(request *state.TransactionalStateRequest) error {
func (p *cockroachDBAccess) ExecuteMulti(ctx context.Context, request *state.TransactionalStateRequest) error {
p.logger.Debug("Executing PostgreSQL transaction")
tx, err := p.db.Begin()
@ -308,7 +309,7 @@ func (p *cockroachDBAccess) ExecuteMulti(request *state.TransactionalStateReques
return err
}
err = p.Set(&setReq)
err = p.Set(ctx, &setReq)
if err != nil {
tx.Rollback()
return err
@ -323,7 +324,7 @@ func (p *cockroachDBAccess) ExecuteMulti(request *state.TransactionalStateReques
return err
}
err = p.Delete(&delReq)
err = p.Delete(ctx, &delReq)
if err != nil {
tx.Rollback()
return err
@ -341,7 +342,7 @@ func (p *cockroachDBAccess) ExecuteMulti(request *state.TransactionalStateReques
}
// Query executes a query against store.
func (p *cockroachDBAccess) Query(req *state.QueryRequest) (*state.QueryResponse, error) {
func (p *cockroachDBAccess) Query(ctx context.Context, req *state.QueryRequest) (*state.QueryResponse, error) {
p.logger.Debug("Getting query value from CockroachDB")
stateQuery := &Query{
@ -361,7 +362,7 @@ func (p *cockroachDBAccess) Query(req *state.QueryRequest) (*state.QueryResponse
p.logger.Debug("Query: " + stateQuery.query)
data, token, err := stateQuery.execute(p.logger, p.db)
data, token, err := stateQuery.execute(ctx, p.logger, p.db)
if err != nil {
return &state.QueryResponse{
Results: []state.QueryItem{},

View File

@ -14,6 +14,7 @@ limitations under the License.
package cockroachdb
import (
"context"
"database/sql"
"testing"
@ -109,7 +110,7 @@ func TestMultiWithNoRequests(t *testing.T) {
var operations []state.TransactionalStateOperation
// Act
err := m.roachDba.ExecuteMulti(&state.TransactionalStateRequest{
err := m.roachDba.ExecuteMulti(context.TODO(), &state.TransactionalStateRequest{
Operations: operations,
})
@ -133,7 +134,7 @@ func TestInvalidMultiInvalidAction(t *testing.T) {
})
// Act
err := m.roachDba.ExecuteMulti(&state.TransactionalStateRequest{
err := m.roachDba.ExecuteMulti(context.TODO(), &state.TransactionalStateRequest{
Operations: operations,
})
@ -158,7 +159,7 @@ func TestValidSetRequest(t *testing.T) {
})
// Act
err := m.roachDba.ExecuteMulti(&state.TransactionalStateRequest{
err := m.roachDba.ExecuteMulti(context.TODO(), &state.TransactionalStateRequest{
Operations: operations,
})
@ -182,7 +183,7 @@ func TestInvalidMultiSetRequest(t *testing.T) {
})
// Act
err := m.roachDba.ExecuteMulti(&state.TransactionalStateRequest{
err := m.roachDba.ExecuteMulti(context.TODO(), &state.TransactionalStateRequest{
Operations: operations,
})
@ -206,7 +207,7 @@ func TestInvalidMultiSetRequestNoKey(t *testing.T) {
})
// Act
err := m.roachDba.ExecuteMulti(&state.TransactionalStateRequest{
err := m.roachDba.ExecuteMulti(context.TODO(), &state.TransactionalStateRequest{
Operations: operations,
})
@ -231,7 +232,7 @@ func TestValidMultiDeleteRequest(t *testing.T) {
})
// Act
err := m.roachDba.ExecuteMulti(&state.TransactionalStateRequest{
err := m.roachDba.ExecuteMulti(context.TODO(), &state.TransactionalStateRequest{
Operations: operations,
})
@ -255,7 +256,7 @@ func TestInvalidMultiDeleteRequest(t *testing.T) {
})
// Act
err := m.roachDba.ExecuteMulti(&state.TransactionalStateRequest{
err := m.roachDba.ExecuteMulti(context.TODO(), &state.TransactionalStateRequest{
Operations: operations,
})
@ -279,7 +280,7 @@ func TestInvalidMultiDeleteRequestNoKey(t *testing.T) {
})
// Act
err := m.roachDba.ExecuteMulti(&state.TransactionalStateRequest{
err := m.roachDba.ExecuteMulti(context.TODO(), &state.TransactionalStateRequest{
Operations: operations,
})
@ -311,7 +312,7 @@ func TestMultiOperationOrder(t *testing.T) {
)
// Act
err := m.roachDba.ExecuteMulti(&state.TransactionalStateRequest{
err := m.roachDba.ExecuteMulti(context.TODO(), &state.TransactionalStateRequest{
Operations: operations,
})
@ -334,7 +335,7 @@ func TestInvalidBulkSetNoKey(t *testing.T) {
})
// Act
err := m.roachDba.BulkSet(sets)
err := m.roachDba.BulkSet(context.TODO(), sets)
// Assert
assert.NotNil(t, err)
@ -356,7 +357,7 @@ func TestInvalidBulkSetEmptyValue(t *testing.T) {
})
// Act
err := m.roachDba.BulkSet(sets)
err := m.roachDba.BulkSet(context.TODO(), sets)
// Assert
assert.NotNil(t, err)
@ -379,7 +380,7 @@ func TestValidBulkSet(t *testing.T) {
})
// Act
err := m.roachDba.BulkSet(sets)
err := m.roachDba.BulkSet(context.TODO(), sets)
// Assert
assert.Nil(t, err)
@ -400,7 +401,7 @@ func TestInvalidBulkDeleteNoKey(t *testing.T) {
})
// Act
err := m.roachDba.BulkDelete(deletes)
err := m.roachDba.BulkDelete(context.TODO(), deletes)
// Assert
assert.NotNil(t, err)
@ -422,7 +423,7 @@ func TestValidBulkDelete(t *testing.T) {
})
// Act
err := m.roachDba.BulkDelete(deletes)
err := m.roachDba.BulkDelete(context.TODO(), deletes)
// Assert
assert.Nil(t, err)

View File

@ -14,6 +14,7 @@ limitations under the License.
package cockroachdb
import (
"context"
"database/sql"
"encoding/json"
"fmt"
@ -211,7 +212,7 @@ func deleteItemThatDoesNotExist(t *testing.T, pgs *CockroachDB) {
Consistency: "",
},
}
err := pgs.Delete(deleteReq)
err := pgs.Delete(context.TODO(), deleteReq)
assert.Nil(t, err)
}
@ -239,7 +240,7 @@ func multiWithSetOnly(t *testing.T, pgs *CockroachDB) {
})
}
err := pgs.Multi(&state.TransactionalStateRequest{
err := pgs.Multi(context.TODO(), &state.TransactionalStateRequest{
Operations: operations,
Metadata: nil,
})
@ -280,7 +281,7 @@ func multiWithDeleteOnly(t *testing.T, pgs *CockroachDB) {
})
}
err := pgs.Multi(&state.TransactionalStateRequest{
err := pgs.Multi(context.TODO(), &state.TransactionalStateRequest{
Operations: operations,
Metadata: nil,
})
@ -341,7 +342,7 @@ func multiWithDeleteAndSet(t *testing.T, pgs *CockroachDB) {
})
}
err := pgs.Multi(&state.TransactionalStateRequest{
err := pgs.Multi(context.TODO(), &state.TransactionalStateRequest{
Operations: operations,
Metadata: nil,
})
@ -376,7 +377,7 @@ func deleteWithInvalidEtagFails(t *testing.T, pgs *CockroachDB) {
Consistency: "",
},
}
err := pgs.Delete(deleteReq)
err := pgs.Delete(context.TODO(), deleteReq)
assert.NotNil(t, err)
}
@ -392,7 +393,7 @@ func deleteWithNoKeyFails(t *testing.T, pgs *CockroachDB) {
Consistency: "",
},
}
err := pgs.Delete(deleteReq)
err := pgs.Delete(context.TODO(), deleteReq)
assert.NotNil(t, err)
}
@ -415,7 +416,7 @@ func newItemWithEtagFails(t *testing.T, pgs *CockroachDB) {
ContentType: nil,
}
err := pgs.Set(setReq)
err := pgs.Set(context.TODO(), setReq)
assert.NotNil(t, err)
}
@ -449,7 +450,7 @@ func updateWithOldEtagFails(t *testing.T, pgs *CockroachDB) {
},
ContentType: nil,
}
err := pgs.Set(setReq)
err := pgs.Set(context.TODO(), setReq)
assert.NotNil(t, err)
}
@ -501,7 +502,7 @@ func getItemWithNoKey(t *testing.T, pgs *CockroachDB) {
},
}
response, getErr := pgs.Get(getReq)
response, getErr := pgs.Get(context.TODO(), getReq)
assert.NotNil(t, getErr)
assert.Nil(t, response)
}
@ -544,7 +545,7 @@ func setItemWithNoKey(t *testing.T, pgs *CockroachDB) {
ContentType: nil,
}
err := pgs.Set(setReq)
err := pgs.Set(context.TODO(), setReq)
assert.NotNil(t, err)
}
@ -563,7 +564,7 @@ func testBulkSetAndBulkDelete(t *testing.T, pgs *CockroachDB) {
},
}
err := pgs.BulkSet(setReq)
err := pgs.BulkSet(context.TODO(), setReq)
assert.Nil(t, err)
assert.True(t, storeItemExists(t, setReq[0].Key))
assert.True(t, storeItemExists(t, setReq[1].Key))
@ -577,7 +578,7 @@ func testBulkSetAndBulkDelete(t *testing.T, pgs *CockroachDB) {
},
}
err = pgs.BulkDelete(deleteReq)
err = pgs.BulkDelete(context.TODO(), deleteReq)
assert.Nil(t, err)
assert.False(t, storeItemExists(t, setReq[0].Key))
assert.False(t, storeItemExists(t, setReq[1].Key))
@ -644,7 +645,7 @@ func setItem(t *testing.T, pgs *CockroachDB, key string, value interface{}, etag
ContentType: nil,
}
err := pgs.Set(setReq)
err := pgs.Set(context.TODO(), setReq)
assert.Nil(t, err)
itemExists := storeItemExists(t, key)
assert.True(t, itemExists)
@ -661,7 +662,7 @@ func getItem(t *testing.T, pgs *CockroachDB, key string) (*state.GetResponse, *f
Metadata: map[string]string{},
}
response, getErr := pgs.Get(getReq)
response, getErr := pgs.Get(context.TODO(), getReq)
assert.Nil(t, getErr)
assert.NotNil(t, response)
outputObject := &fakeItem{
@ -685,7 +686,7 @@ func deleteItem(t *testing.T, pgs *CockroachDB, key string, etag *string) {
Metadata: map[string]string{},
}
deleteErr := pgs.Delete(deleteReq)
deleteErr := pgs.Delete(context.TODO(), deleteReq)
assert.Nil(t, deleteErr)
assert.False(t, storeItemExists(t, key))
}

View File

@ -14,6 +14,7 @@ limitations under the License.
package cockroachdb
import (
"context"
"database/sql"
"fmt"
"strconv"
@ -133,8 +134,8 @@ func (q *Query) Finalize(filters string, storeQuery *query.Query) error {
return nil
}
func (q *Query) execute(logger logger.Logger, db *sql.DB) ([]state.QueryItem, string, error) {
rows, err := db.Query(q.query, q.params...)
func (q *Query) execute(ctx context.Context, logger logger.Logger, db *sql.DB) ([]state.QueryItem, string, error) {
rows, err := db.QueryContext(ctx, q.query, q.params...)
if err != nil {
return nil, "", fmt.Errorf("query executes '%s' failed: %w", q.query, err)
}

View File

@ -14,6 +14,7 @@ limitations under the License.
package cockroachdb
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
@ -42,37 +43,37 @@ func (m *fakeDBaccess) Init(metadata state.Metadata) error {
return nil
}
func (m *fakeDBaccess) Set(req *state.SetRequest) error {
func (m *fakeDBaccess) Set(ctx context.Context, req *state.SetRequest) error {
m.setExecuted = true
return nil
}
func (m *fakeDBaccess) Get(req *state.GetRequest) (*state.GetResponse, error) {
func (m *fakeDBaccess) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) {
m.getExecuted = true
return nil, nil
}
func (m *fakeDBaccess) Delete(req *state.DeleteRequest) error {
func (m *fakeDBaccess) Delete(ctx context.Context, req *state.DeleteRequest) error {
m.deleteExecuted = true
return nil
}
func (m *fakeDBaccess) BulkSet(req []state.SetRequest) error {
func (m *fakeDBaccess) BulkSet(ctx context.Context, req []state.SetRequest) error {
return nil
}
func (m *fakeDBaccess) BulkDelete(req []state.DeleteRequest) error {
func (m *fakeDBaccess) BulkDelete(ctx context.Context, req []state.DeleteRequest) error {
return nil
}
func (m *fakeDBaccess) ExecuteMulti(req *state.TransactionalStateRequest) error {
func (m *fakeDBaccess) ExecuteMulti(ctx context.Context, req *state.TransactionalStateRequest) error {
return nil
}
func (m *fakeDBaccess) Query(req *state.QueryRequest) (*state.QueryResponse, error) {
func (m *fakeDBaccess) Query(ctx context.Context, req *state.QueryRequest) (*state.QueryResponse, error) {
return nil, nil
}

View File

@ -13,18 +13,22 @@ limitations under the License.
package cockroachdb
import "github.com/dapr/components-contrib/state"
import (
"context"
"github.com/dapr/components-contrib/state"
)
// dbAccess is a private interface which enables unit testing of CockroachDB.
type dbAccess interface {
Init(metadata state.Metadata) error
Set(req *state.SetRequest) error
BulkSet(req []state.SetRequest) error
Get(req *state.GetRequest) (*state.GetResponse, error)
Delete(req *state.DeleteRequest) error
BulkDelete(req []state.DeleteRequest) error
ExecuteMulti(req *state.TransactionalStateRequest) error
Query(req *state.QueryRequest) (*state.QueryResponse, error)
Set(ctx context.Context, req *state.SetRequest) error
BulkSet(ctx context.Context, req []state.SetRequest) error
Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error)
Delete(ctx context.Context, req *state.DeleteRequest) error
BulkDelete(ctx context.Context, req []state.DeleteRequest) error
ExecuteMulti(ctx context.Context, req *state.TransactionalStateRequest) error
Query(ctx context.Context, req *state.QueryRequest) (*state.QueryResponse, error)
Ping() error
Close() error
}

View File

@ -14,6 +14,7 @@ limitations under the License.
package couchbase
import (
"context"
"errors"
"fmt"
"strconv"
@ -144,7 +145,7 @@ func (cbs *Couchbase) Features() []state.Feature {
}
// Set stores value for a key to couchbase. It honors ETag (for concurrency) and consistency settings.
func (cbs *Couchbase) Set(req *state.SetRequest) error {
func (cbs *Couchbase) Set(ctx context.Context, req *state.SetRequest) error {
err := state.CheckRequestOptions(req.Options)
if err != nil {
return err
@ -188,7 +189,7 @@ func (cbs *Couchbase) Set(req *state.SetRequest) error {
}
// Get retrieves state from couchbase with a key.
func (cbs *Couchbase) Get(req *state.GetRequest) (*state.GetResponse, error) {
func (cbs *Couchbase) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) {
var data interface{}
cas, err := cbs.bucket.Get(req.Key, &data)
if err != nil {
@ -206,7 +207,7 @@ func (cbs *Couchbase) Get(req *state.GetRequest) (*state.GetResponse, error) {
}
// Delete performs a delete operation.
func (cbs *Couchbase) Delete(req *state.DeleteRequest) error {
func (cbs *Couchbase) Delete(ctx context.Context, req *state.DeleteRequest) error {
err := state.CheckRequestOptions(req.Options)
if err != nil {
return err

View File

@ -93,12 +93,12 @@ func (f *Firestore) Features() []state.Feature {
}
// Get retrieves state from Firestore with a key (Always strong consistency).
func (f *Firestore) Get(req *state.GetRequest) (*state.GetResponse, error) {
func (f *Firestore) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) {
key := req.Key
entityKey := datastore.NameKey(f.entityKind, key, nil)
var entity StateEntity
err := f.client.Get(context.Background(), entityKey, &entity)
err := f.client.Get(ctx, entityKey, &entity)
if err != nil && !errors.Is(err, datastore.ErrNoSuchEntity) {
return nil, err
@ -111,7 +111,7 @@ func (f *Firestore) Get(req *state.GetRequest) (*state.GetResponse, error) {
}, nil
}
func (f *Firestore) setValue(req *state.SetRequest) error {
func (f *Firestore) setValue(ctx context.Context, req *state.SetRequest) error {
err := state.CheckRequestOptions(req.Options)
if err != nil {
return err
@ -128,7 +128,6 @@ func (f *Firestore) setValue(req *state.SetRequest) error {
entity := &StateEntity{
Value: v,
}
ctx := context.Background()
key := datastore.NameKey(f.entityKind, req.Key, nil)
_, err = f.client.Put(ctx, key, entity)
@ -141,12 +140,11 @@ func (f *Firestore) setValue(req *state.SetRequest) error {
}
// Set saves state into Firestore with retry.
func (f *Firestore) Set(req *state.SetRequest) error {
return state.SetWithOptions(f.setValue, req)
func (f *Firestore) Set(ctx context.Context, req *state.SetRequest) error {
return state.SetWithOptions(ctx, f.setValue, req)
}
func (f *Firestore) deleteValue(req *state.DeleteRequest) error {
ctx := context.Background()
func (f *Firestore) deleteValue(ctx context.Context, req *state.DeleteRequest) error {
key := datastore.NameKey(f.entityKind, req.Key, nil)
err := f.client.Delete(ctx, key)
@ -158,8 +156,8 @@ func (f *Firestore) deleteValue(req *state.DeleteRequest) error {
}
// Delete performs a delete operation.
func (f *Firestore) Delete(req *state.DeleteRequest) error {
return state.DeleteWithOptions(f.deleteValue, req)
func (f *Firestore) Delete(ctx context.Context, req *state.DeleteRequest) error {
return state.DeleteWithOptions(ctx, f.deleteValue, req)
}
func getFirestoreMetadata(metadata state.Metadata) (*firestoreMetadata, error) {

View File

@ -14,6 +14,7 @@ limitations under the License.
package consul
import (
"context"
"encoding/json"
"fmt"
@ -102,11 +103,12 @@ func metadataToConfig(connInfo map[string]string) (*consulConfig, error) {
}
// Get retrieves a Consul KV item.
func (c *Consul) Get(req *state.GetRequest) (*state.GetResponse, error) {
func (c *Consul) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) {
queryOpts := &api.QueryOptions{}
if req.Options.Consistency == state.Strong {
queryOpts.RequireConsistent = true
}
queryOpts = queryOpts.WithContext(ctx)
resp, queryMeta, err := c.client.KV().Get(fmt.Sprintf("%s/%s", c.keyPrefixPath, req.Key), queryOpts)
if err != nil {
@ -124,7 +126,7 @@ func (c *Consul) Get(req *state.GetRequest) (*state.GetResponse, error) {
}
// Set saves a Consul KV item.
func (c *Consul) Set(req *state.SetRequest) error {
func (c *Consul) Set(ctx context.Context, req *state.SetRequest) error {
var reqValByte []byte
b, ok := req.Value.([]byte)
if ok {
@ -135,10 +137,12 @@ func (c *Consul) Set(req *state.SetRequest) error {
keyWithPath := fmt.Sprintf("%s/%s", c.keyPrefixPath, req.Key)
writeOptions := new(api.WriteOptions)
writeOptions = writeOptions.WithContext(ctx)
_, err := c.client.KV().Put(&api.KVPair{
Key: keyWithPath,
Value: reqValByte,
}, nil)
}, writeOptions)
if err != nil {
return fmt.Errorf("couldn't set key %s: %s", keyWithPath, err)
}
@ -147,9 +151,11 @@ func (c *Consul) Set(req *state.SetRequest) error {
}
// Delete performes a Consul KV delete operation.
func (c *Consul) Delete(req *state.DeleteRequest) error {
func (c *Consul) Delete(ctx context.Context, req *state.DeleteRequest) error {
keyWithPath := fmt.Sprintf("%s/%s", c.keyPrefixPath, req.Key)
_, err := c.client.KV().Delete(keyWithPath, nil)
writeOptions := new(api.WriteOptions)
writeOptions = writeOptions.WithContext(ctx)
_, err := c.client.KV().Delete(keyWithPath, writeOptions)
if err != nil {
return fmt.Errorf("couldn't delete key %s: %s", keyWithPath, err)
}

View File

@ -14,6 +14,7 @@ limitations under the License.
package hazelcast
import (
"context"
"errors"
"fmt"
"strings"
@ -91,7 +92,7 @@ func (store *Hazelcast) Features() []state.Feature {
}
// Set stores value for a key to Hazelcast.
func (store *Hazelcast) Set(req *state.SetRequest) error {
func (store *Hazelcast) Set(ctx context.Context, req *state.SetRequest) error {
err := state.CheckRequestOptions(req)
if err != nil {
return err
@ -117,7 +118,7 @@ func (store *Hazelcast) Set(req *state.SetRequest) error {
}
// Get retrieves state from Hazelcast with a key.
func (store *Hazelcast) Get(req *state.GetRequest) (*state.GetResponse, error) {
func (store *Hazelcast) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) {
resp, err := store.hzMap.Get(req.Key)
if err != nil {
return nil, fmt.Errorf("hazelcast error: failed to get value for %s: %s", req.Key, err)
@ -138,7 +139,7 @@ func (store *Hazelcast) Get(req *state.GetRequest) (*state.GetResponse, error) {
}
// Delete performs a delete operation.
func (store *Hazelcast) Delete(req *state.DeleteRequest) error {
func (store *Hazelcast) Delete(ctx context.Context, req *state.DeleteRequest) error {
err := state.CheckRequestOptions(req.Options)
if err != nil {
return err

View File

@ -73,7 +73,7 @@ func (store *inMemoryStore) Features() []state.Feature {
return []state.Feature{state.FeatureETag, state.FeatureTransactional}
}
func (store *inMemoryStore) Delete(req *state.DeleteRequest) error {
func (store *inMemoryStore) Delete(ctx context.Context, req *state.DeleteRequest) error {
// step1: validate parameters
if err := store.doDeleteValidateParameters(req); err != nil {
return err
@ -90,7 +90,7 @@ func (store *inMemoryStore) Delete(req *state.DeleteRequest) error {
// step3: do really delete
// this operation won't fail
store.doDelete(req.Key)
store.doDelete(ctx, req.Key)
return nil
}
@ -117,11 +117,11 @@ func (store *inMemoryStore) doValidateEtag(key string, etag *string, concurrency
return nil
}
func (store *inMemoryStore) doDelete(key string) {
func (store *inMemoryStore) doDelete(ctx context.Context, key string) {
delete(store.items, key)
}
func (store *inMemoryStore) BulkDelete(req []state.DeleteRequest) error {
func (store *inMemoryStore) BulkDelete(ctx context.Context, req []state.DeleteRequest) error {
if len(req) == 0 {
return nil
}
@ -148,15 +148,15 @@ func (store *inMemoryStore) BulkDelete(req []state.DeleteRequest) error {
// step3: do really delete
for _, dr := range req {
store.doDelete(dr.Key)
store.doDelete(ctx, dr.Key)
}
return nil
}
func (store *inMemoryStore) Get(req *state.GetRequest) (*state.GetResponse, error) {
item := store.doGetWithReadLock(req.Key)
func (store *inMemoryStore) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) {
item := store.doGetWithReadLock(ctx, req.Key)
if item != nil && isExpired(item.expire) {
item = store.doGetWithWriteLock(req.Key)
item = store.doGetWithWriteLock(ctx, req.Key)
}
if item == nil {
@ -165,14 +165,14 @@ func (store *inMemoryStore) Get(req *state.GetRequest) (*state.GetResponse, erro
return &state.GetResponse{Data: unmarshal(item.data), ETag: item.etag}, nil
}
func (store *inMemoryStore) doGetWithReadLock(key string) *inMemStateStoreItem {
func (store *inMemoryStore) doGetWithReadLock(ctx context.Context, key string) *inMemStateStoreItem {
store.lock.RLock()
defer store.lock.RUnlock()
return store.items[key]
}
func (store *inMemoryStore) doGetWithWriteLock(key string) *inMemStateStoreItem {
func (store *inMemoryStore) doGetWithWriteLock(ctx context.Context, key string) *inMemStateStoreItem {
store.lock.Lock()
defer store.lock.Unlock()
// get item and check expired again to avoid if item changed between we got this write-lock
@ -181,7 +181,7 @@ func (store *inMemoryStore) doGetWithWriteLock(key string) *inMemStateStoreItem
return nil
}
if isExpired(item.expire) {
store.doDelete(key)
store.doDelete(ctx, key)
return nil
}
return item
@ -194,11 +194,11 @@ func isExpired(expire int64) bool {
return time.Now().UnixMilli() > expire
}
func (store *inMemoryStore) BulkGet(req []state.GetRequest) (bool, []state.BulkGetResponse, error) {
func (store *inMemoryStore) BulkGet(ctx context.Context, req []state.GetRequest) (bool, []state.BulkGetResponse, error) {
return false, nil, nil
}
func (store *inMemoryStore) Set(req *state.SetRequest) error {
func (store *inMemoryStore) Set(ctx context.Context, req *state.SetRequest) error {
// step1: validate parameters
ttlInSeconds, err := store.doSetValidateParameters(req)
if err != nil {
@ -217,7 +217,7 @@ func (store *inMemoryStore) Set(req *state.SetRequest) error {
// step3: do really set
// this operation won't fail
store.doSet(req.Key, b, req.ETag, ttlInSeconds)
store.doSet(ctx, req.Key, b, req.ETag, ttlInSeconds)
return nil
}
@ -253,7 +253,7 @@ func doParseTTLInSeconds(metadata map[string]string) (int, error) {
return i, nil
}
func (store *inMemoryStore) doSet(key string, data []byte, etag *string, ttlInSeconds int) {
func (store *inMemoryStore) doSet(ctx context.Context, key string, data []byte, etag *string, ttlInSeconds int) {
store.items[key] = &inMemStateStoreItem{
data: data,
etag: etag,
@ -268,7 +268,7 @@ type innerSetRequest struct {
data []byte
}
func (store *inMemoryStore) BulkSet(req []state.SetRequest) error {
func (store *inMemoryStore) BulkSet(ctx context.Context, req []state.SetRequest) error {
if len(req) == 0 {
return nil
}
@ -305,12 +305,12 @@ func (store *inMemoryStore) BulkSet(req []state.SetRequest) error {
// step3: do really set
// these operations won't fail
for _, innerSetRequest := range innerSetRequestList {
store.doSet(innerSetRequest.req.Key, innerSetRequest.data, innerSetRequest.req.ETag, innerSetRequest.ttlInSeconds)
store.doSet(ctx, innerSetRequest.req.Key, innerSetRequest.data, innerSetRequest.req.ETag, innerSetRequest.ttlInSeconds)
}
return nil
}
func (store *inMemoryStore) Multi(request *state.TransactionalStateRequest) error {
func (store *inMemoryStore) Multi(ctx context.Context, request *state.TransactionalStateRequest) error {
if len(request.Operations) == 0 {
return nil
}
@ -366,10 +366,10 @@ func (store *inMemoryStore) Multi(request *state.TransactionalStateRequest) erro
for _, o := range request.Operations {
if o.Operation == state.Upsert {
s := o.Request.(innerSetRequest)
store.doSet(s.req.Key, s.data, s.req.ETag, s.ttlInSeconds)
store.doSet(ctx, s.req.Key, s.data, s.req.ETag, s.ttlInSeconds)
} else if o.Operation == state.Delete {
d := o.Request.(state.DeleteRequest)
store.doDelete(d.Key)
store.doDelete(ctx, d.Key)
}
}
return nil
@ -406,7 +406,7 @@ func (store *inMemoryStore) doCleanExpiredItems() {
for key, item := range store.items {
if isExpired(item.expire) {
store.doDelete(key)
store.doDelete(context.Background(), key)
}
}
}

View File

@ -14,6 +14,7 @@ limitations under the License.
package inmemory
import (
"context"
"testing"
"time"
@ -43,13 +44,13 @@ func TestReadAndWrite(t *testing.T) {
Value: valueA,
ETag: ptr.String("the etag"),
}
err := store.Set(setReq)
err := store.Set(context.TODO(), setReq)
assert.Nil(t, err)
// get after set
getReq := &state.GetRequest{
Key: keyA,
}
resp, err := store.Get(getReq)
resp, err := store.Get(context.TODO(), getReq)
assert.Nil(t, err)
assert.NotNil(t, resp)
assert.Equal(t, valueA, string(resp.Data))
@ -62,7 +63,7 @@ func TestReadAndWrite(t *testing.T) {
Value: valueA,
Metadata: map[string]string{"ttlInSeconds": "1"},
}
err := store.Set(setReq)
err := store.Set(context.TODO(), setReq)
assert.Nil(t, err)
// simulate expiration
time.Sleep(2 * time.Second)
@ -70,7 +71,7 @@ func TestReadAndWrite(t *testing.T) {
getReq := &state.GetRequest{
Key: keyA,
}
resp, err := store.Get(getReq)
resp, err := store.Get(context.TODO(), getReq)
assert.Nil(t, err)
assert.NotNil(t, resp)
assert.Nil(t, resp.Data)
@ -84,20 +85,20 @@ func TestReadAndWrite(t *testing.T) {
Value: "1234",
ETag: ptr.String("the etag"),
}
err := store.Set(setReq)
err := store.Set(context.TODO(), setReq)
assert.Nil(t, err)
// get
getReq := &state.GetRequest{
Key: "theSecondKey",
}
resp, err := store.Get(getReq)
resp, err := store.Get(context.TODO(), getReq)
assert.Nil(t, err)
assert.NotNil(t, resp)
assert.Equal(t, "1234", string(resp.Data))
})
t.Run("BulkSet two keys", func(t *testing.T) {
err := store.BulkSet([]state.SetRequest{{
err := store.BulkSet(context.TODO(), []state.SetRequest{{
Key: "theFirstKey",
Value: "666",
}, {
@ -109,7 +110,7 @@ func TestReadAndWrite(t *testing.T) {
})
t.Run("BulkGet fails when not supported", func(t *testing.T) {
supportBulk, _, err := store.BulkGet([]state.GetRequest{{
supportBulk, _, err := store.BulkGet(context.TODO(), []state.GetRequest{{
Key: "theFirstKey",
}, {
Key: "theSecondKey",
@ -123,7 +124,7 @@ func TestReadAndWrite(t *testing.T) {
req := &state.DeleteRequest{
Key: "theFirstKey",
}
err := store.Delete(req)
err := store.Delete(context.TODO(), req)
assert.Nil(t, err)
})
}

View File

@ -14,6 +14,7 @@ limitations under the License.
package jetstream
import (
"context"
"fmt"
"strings"
@ -97,7 +98,7 @@ func (js *StateStore) Features() []state.Feature {
}
// Get retrieves state with a key.
func (js *StateStore) Get(req *state.GetRequest) (*state.GetResponse, error) {
func (js *StateStore) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) {
entry, err := js.bucket.Get(escape(req.Key))
if err != nil {
return nil, err
@ -109,14 +110,14 @@ func (js *StateStore) Get(req *state.GetRequest) (*state.GetResponse, error) {
}
// Set stores value for a key.
func (js *StateStore) Set(req *state.SetRequest) error {
func (js *StateStore) Set(ctx context.Context, req *state.SetRequest) error {
bt, _ := utils.Marshal(req.Value, js.json.Marshal)
_, err := js.bucket.Put(escape(req.Key), bt)
return err
}
// Delete performs a delete operation.
func (js *StateStore) Delete(req *state.DeleteRequest) error {
func (js *StateStore) Delete(ctx context.Context, req *state.DeleteRequest) error {
return js.bucket.Delete(escape(req.Key))
}

View File

@ -14,6 +14,7 @@ limitations under the License.
package jetstream
import (
"context"
"encoding/json"
"fmt"
"reflect"
@ -112,7 +113,7 @@ func TestSetGetAndDelete(t *testing.T) {
"dkey": "dvalue",
}
err = store.Set(&state.SetRequest{
err = store.Set(context.TODO(), &state.SetRequest{
Key: tkey,
Value: tData,
})
@ -121,7 +122,7 @@ func TestSetGetAndDelete(t *testing.T) {
return
}
resp, err := store.Get(&state.GetRequest{
resp, err := store.Get(context.TODO(), &state.GetRequest{
Key: tkey,
})
if err != nil {
@ -134,7 +135,7 @@ func TestSetGetAndDelete(t *testing.T) {
t.Fatal("Response data does not match written data\n")
}
err = store.Delete(&state.DeleteRequest{
err = store.Delete(context.TODO(), &state.DeleteRequest{
Key: tkey,
})
if err != nil {
@ -142,7 +143,7 @@ func TestSetGetAndDelete(t *testing.T) {
return
}
_, err = store.Get(&state.GetRequest{
_, err = store.Get(context.TODO(), &state.GetRequest{
Key: tkey,
})
if err == nil {

View File

@ -14,6 +14,7 @@ limitations under the License.
package memcached
import (
"context"
"errors"
"fmt"
"strconv"
@ -139,7 +140,7 @@ func (m *Memcached) parseTTL(req *state.SetRequest) (*int32, error) {
return nil, nil
}
func (m *Memcached) setValue(req *state.SetRequest) error {
func (m *Memcached) setValue(ctx context.Context, req *state.SetRequest) error {
var bt []byte
ttl, err := m.parseTTL(req)
if err != nil {
@ -159,7 +160,7 @@ func (m *Memcached) setValue(req *state.SetRequest) error {
return nil
}
func (m *Memcached) Delete(req *state.DeleteRequest) error {
func (m *Memcached) Delete(ctx context.Context, req *state.DeleteRequest) error {
err := m.client.Delete(req.Key)
if err != nil {
if err == memcache.ErrCacheMiss {
@ -171,7 +172,7 @@ func (m *Memcached) Delete(req *state.DeleteRequest) error {
return nil
}
func (m *Memcached) Get(req *state.GetRequest) (*state.GetResponse, error) {
func (m *Memcached) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) {
item, err := m.client.Get(req.Key)
if err != nil {
// Return nil for status 204
@ -187,6 +188,6 @@ func (m *Memcached) Get(req *state.GetRequest) (*state.GetResponse, error) {
}, nil
}
func (m *Memcached) Set(req *state.SetRequest) error {
return state.SetWithOptions(m.setValue, req)
func (m *Memcached) Set(ctx context.Context, req *state.SetRequest) error {
return state.SetWithOptions(ctx, m.setValue, req)
}

View File

@ -159,10 +159,7 @@ func (m *MongoDB) Features() []state.Feature {
}
// Set saves state into MongoDB.
func (m *MongoDB) Set(req *state.SetRequest) error {
ctx, cancel := context.WithTimeout(context.Background(), m.operationTimeout)
defer cancel()
func (m *MongoDB) Set(ctx context.Context, req *state.SetRequest) error {
err := m.setInternal(ctx, req)
if err != nil {
return err
@ -205,12 +202,9 @@ func (m *MongoDB) setInternal(ctx context.Context, req *state.SetRequest) error
}
// Get retrieves state from MongoDB with a key.
func (m *MongoDB) Get(req *state.GetRequest) (*state.GetResponse, error) {
func (m *MongoDB) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) {
var result Item
ctx, cancel := context.WithTimeout(context.Background(), m.operationTimeout)
defer cancel()
filter := bson.M{id: req.Key}
err := m.collection.FindOne(ctx, filter).Decode(&result)
if err != nil {
@ -264,10 +258,7 @@ func (m *MongoDB) Get(req *state.GetRequest) (*state.GetResponse, error) {
}
// Delete performs a delete operation.
func (m *MongoDB) Delete(req *state.DeleteRequest) error {
ctx, cancel := context.WithTimeout(context.Background(), m.operationTimeout)
defer cancel()
func (m *MongoDB) Delete(ctx context.Context, req *state.DeleteRequest) error {
err := m.deleteInternal(ctx, req)
if err != nil {
return err
@ -294,18 +285,18 @@ func (m *MongoDB) deleteInternal(ctx context.Context, req *state.DeleteRequest)
}
// Multi performs a transactional operation. succeeds only if all operations succeed, and fails if one or more operations fail.
func (m *MongoDB) Multi(request *state.TransactionalStateRequest) error {
func (m *MongoDB) Multi(ctx context.Context, request *state.TransactionalStateRequest) error {
sess, err := m.client.StartSession()
txnOpts := options.Transaction().SetReadConcern(readconcern.Snapshot()).
SetWriteConcern(writeconcern.New(writeconcern.WMajority()))
defer sess.EndSession(context.Background())
defer sess.EndSession(ctx)
if err != nil {
return fmt.Errorf("error in starting the transaction: %s", err)
}
sess.WithTransaction(context.Background(), func(sessCtx mongo.SessionContext) (interface{}, error) {
sess.WithTransaction(ctx, func(sessCtx mongo.SessionContext) (interface{}, error) {
err = m.doTransaction(sessCtx, request.Operations)
return nil, err
@ -336,10 +327,7 @@ func (m *MongoDB) doTransaction(sessCtx mongo.SessionContext, operations []state
}
// Query executes a query against store.
func (m *MongoDB) Query(req *state.QueryRequest) (*state.QueryResponse, error) {
ctx, cancel := context.WithTimeout(context.Background(), m.operationTimeout)
defer cancel()
func (m *MongoDB) Query(ctx context.Context, req *state.QueryRequest) (*state.QueryResponse, error) {
q := &Query{}
qbuilder := query.NewQueryBuilder(q)
if err := qbuilder.BuildQuery(&req.Query); err != nil {

View File

@ -14,6 +14,7 @@ limitations under the License.
package mysql
import (
"context"
"database/sql"
"encoding/base64"
"encoding/json"
@ -279,13 +280,13 @@ func tableExists(db *sql.DB, tableName string) (bool, error) {
// Delete removes an entity from the store
// Store Interface.
func (m *MySQL) Delete(req *state.DeleteRequest) error {
return state.DeleteWithOptions(m.deleteValue, req)
func (m *MySQL) Delete(ctx context.Context, req *state.DeleteRequest) error {
return state.DeleteWithOptions(ctx, m.deleteValue, req)
}
// deleteValue is an internal implementation of delete to enable passing the
// logic to state.DeleteWithRetries as a func.
func (m *MySQL) deleteValue(req *state.DeleteRequest) error {
func (m *MySQL) deleteValue(ctx context.Context, req *state.DeleteRequest) error {
m.logger.Debug("Deleting state value from MySql")
if req.Key == "" {
@ -296,11 +297,11 @@ func (m *MySQL) deleteValue(req *state.DeleteRequest) error {
var result sql.Result
if req.ETag == nil || *req.ETag == "" {
result, err = m.db.Exec(fmt.Sprintf(
result, err = m.db.ExecContext(ctx, fmt.Sprintf(
`DELETE FROM %s WHERE id = ?`,
m.tableName), req.Key)
} else {
result, err = m.db.Exec(fmt.Sprintf(
result, err = m.db.ExecContext(ctx, fmt.Sprintf(
`DELETE FROM %s WHERE id = ? and eTag = ?`,
m.tableName), req.Key, *req.ETag)
}
@ -323,7 +324,7 @@ func (m *MySQL) deleteValue(req *state.DeleteRequest) error {
// BulkDelete removes multiple entries from the store
// Store Interface.
func (m *MySQL) BulkDelete(req []state.DeleteRequest) error {
func (m *MySQL) BulkDelete(ctx context.Context, req []state.DeleteRequest) error {
m.logger.Debug("Executing BulkDelete request")
tx, err := m.db.Begin()
@ -334,7 +335,7 @@ func (m *MySQL) BulkDelete(req []state.DeleteRequest) error {
if len(req) > 0 {
for _, d := range req {
da := d // Fix for goSec G601: Implicit memory aliasing in for loop.
err = m.Delete(&da)
err = m.Delete(ctx, &da)
if err != nil {
tx.Rollback()
@ -350,7 +351,7 @@ func (m *MySQL) BulkDelete(req []state.DeleteRequest) error {
// Get returns an entity from store
// Store Interface.
func (m *MySQL) Get(req *state.GetRequest) (*state.GetResponse, error) {
func (m *MySQL) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) {
m.logger.Debug("Getting state value from MySql")
if req.Key == "" {
@ -368,7 +369,7 @@ func (m *MySQL) Get(req *state.GetRequest) (*state.GetResponse, error) {
`SELECT value, eTag, isbinary FROM %s WHERE id = ?`,
m.tableName, // m.tableName is sanitized
)
err := m.db.QueryRow(query, req.Key).Scan(&value, &eTag, &isBinary)
err := m.db.QueryRowContext(ctx, query, req.Key).Scan(&value, &eTag, &isBinary)
if err != nil {
// If no rows exist, return an empty response, otherwise return an error.
if errors.Is(err, sql.ErrNoRows) {
@ -410,13 +411,13 @@ func (m *MySQL) Get(req *state.GetRequest) (*state.GetResponse, error) {
// Set adds/updates an entity on store
// Store Interface.
func (m *MySQL) Set(req *state.SetRequest) error {
return state.SetWithOptions(m.setValue, req)
func (m *MySQL) Set(ctx context.Context, req *state.SetRequest) error {
return state.SetWithOptions(ctx, m.setValue, req)
}
// setValue is an internal implementation of set to enable passing the logic
// to state.SetWithRetries as a func.
func (m *MySQL) setValue(req *state.SetRequest) error {
func (m *MySQL) setValue(ctx context.Context, req *state.SetRequest) error {
m.logger.Debug("Setting state value in MySql")
err := state.CheckRequestOptions(req.Options)
@ -457,7 +458,7 @@ func (m *MySQL) setValue(req *state.SetRequest) error {
`INSERT INTO %s (value, id, eTag, isbinary) VALUES (?, ?, ?, ?);`,
m.tableName, // m.tableName is sanitized
)
result, err = m.db.Exec(query, enc, req.Key, eTag, isBinary)
result, err = m.db.ExecContext(ctx, query, enc, req.Key, eTag, isBinary)
} else if req.ETag != nil && *req.ETag != "" {
// When an eTag is provided do an update - not insert
//nolint:gosec
@ -465,7 +466,7 @@ func (m *MySQL) setValue(req *state.SetRequest) error {
`UPDATE %s SET value = ?, eTag = ?, isbinary = ? WHERE id = ? AND eTag = ?;`,
m.tableName, // m.tableName is sanitized
)
result, err = m.db.Exec(query, enc, eTag, isBinary, req.Key, *req.ETag)
result, err = m.db.ExecContext(ctx, query, enc, eTag, isBinary, req.Key, *req.ETag)
} else {
// If this is a duplicate MySQL returns that two rows affected
maxRows = 2
@ -474,7 +475,7 @@ func (m *MySQL) setValue(req *state.SetRequest) error {
`INSERT INTO %s (value, id, eTag, isbinary) VALUES (?, ?, ?, ?) on duplicate key update value=?, eTag=?, isbinary=?;`,
m.tableName, // m.tableName is sanitized
)
result, err = m.db.Exec(query, enc, req.Key, eTag, isBinary, enc, eTag, isBinary)
result, err = m.db.ExecContext(ctx, query, enc, req.Key, eTag, isBinary, enc, eTag, isBinary)
}
if err != nil {
@ -508,7 +509,7 @@ func (m *MySQL) setValue(req *state.SetRequest) error {
// BulkSet adds/updates multiple entities on store
// Store Interface.
func (m *MySQL) BulkSet(req []state.SetRequest) error {
func (m *MySQL) BulkSet(ctx context.Context, req []state.SetRequest) error {
m.logger.Debug("Executing BulkSet request")
tx, err := m.db.Begin()
@ -518,7 +519,7 @@ func (m *MySQL) BulkSet(req []state.SetRequest) error {
if len(req) > 0 {
for i := range req {
err = m.Set(&req[i])
err = m.Set(ctx, &req[i])
if err != nil {
tx.Rollback()
return err
@ -531,7 +532,7 @@ func (m *MySQL) BulkSet(req []state.SetRequest) error {
// Multi handles multiple transactions.
// TransactionalStore Interface.
func (m *MySQL) Multi(request *state.TransactionalStateRequest) error {
func (m *MySQL) Multi(ctx context.Context, request *state.TransactionalStateRequest) error {
m.logger.Debug("Executing Multi request")
tx, err := m.db.Begin()
@ -548,7 +549,7 @@ func (m *MySQL) Multi(request *state.TransactionalStateRequest) error {
return err
}
err = m.Set(&setReq)
err = m.Set(ctx, &setReq)
if err != nil {
_ = tx.Rollback()
return err
@ -561,7 +562,7 @@ func (m *MySQL) Multi(request *state.TransactionalStateRequest) error {
return err
}
err = m.Delete(&delReq)
err = m.Delete(ctx, &delReq)
if err != nil {
_ = tx.Rollback()
return err
@ -604,7 +605,7 @@ func (m *MySQL) getDeletes(req state.TransactionalStateOperation) (state.DeleteR
}
// BulkGet performs a bulks get operations.
func (m *MySQL) BulkGet(req []state.GetRequest) (bool, []state.BulkGetResponse, error) {
func (m *MySQL) BulkGet(ctx context.Context, req []state.GetRequest) (bool, []state.BulkGetResponse, error) {
// by default, the store doesn't support bulk get
// return false so daprd will fallback to call get() method one by one
return false, nil, nil

View File

@ -15,6 +15,7 @@ limitations under the License.
package mysql
import (
"context"
"crypto/tls"
"crypto/x509"
"database/sql"
@ -205,7 +206,7 @@ func TestMySQLIntegration(t *testing.T) {
Key: "",
}
response, getErr := mys.Get(getReq)
response, getErr := mys.Get(context.TODO(), getReq)
assert.NotNil(t, getErr)
assert.Nil(t, response)
})
@ -242,7 +243,7 @@ func TestMySQLIntegration(t *testing.T) {
Key: "",
}
err := mys.Set(setReq)
err := mys.Set(context.TODO(), setReq)
assert.NotNil(t, err, "Error was not nil when setting item with no key.")
})
@ -302,7 +303,7 @@ func TestMySQLIntegration(t *testing.T) {
Value: newValue,
}
err := mys.Set(setReq)
err := mys.Set(context.TODO(), setReq)
assert.NotNil(t, err, "Error was not thrown using old eTag")
})
@ -318,7 +319,7 @@ func TestMySQLIntegration(t *testing.T) {
Value: value,
}
err := mys.Set(setReq)
err := mys.Set(context.TODO(), setReq)
assert.NotNil(t, err)
})
@ -338,7 +339,7 @@ func TestMySQLIntegration(t *testing.T) {
ETag: &eTag,
}
err := mys.Delete(deleteReq)
err := mys.Delete(context.TODO(), deleteReq)
assert.NotNil(t, err)
})
@ -349,7 +350,7 @@ func TestMySQLIntegration(t *testing.T) {
Key: "",
}
err := mys.Delete(deleteReq)
err := mys.Delete(context.TODO(), deleteReq)
assert.NotNil(t, err)
})
@ -361,7 +362,7 @@ func TestMySQLIntegration(t *testing.T) {
Key: randomKey(),
}
err := mys.Delete(deleteReq)
err := mys.Delete(context.TODO(), deleteReq)
assert.Nil(t, err)
})
@ -378,7 +379,7 @@ func TestMySQLIntegration(t *testing.T) {
},
}
err := mys.Set(setReq)
err := mys.Set(context.TODO(), setReq)
assert.NoError(t, err)
// Get the etag
@ -396,7 +397,7 @@ func TestMySQLIntegration(t *testing.T) {
},
}
err = mys.Set(setReq)
err = mys.Set(context.TODO(), setReq)
assert.ErrorContains(t, err, "Duplicate entry")
// Insert with invalid etag should fail on existing keys
@ -409,7 +410,7 @@ func TestMySQLIntegration(t *testing.T) {
},
}
err = mys.Set(setReq)
err = mys.Set(context.TODO(), setReq)
assert.ErrorContains(t, err, "possible etag mismatch")
// Insert with valid etag should succeed on existing keys
@ -422,7 +423,7 @@ func TestMySQLIntegration(t *testing.T) {
},
}
err = mys.Set(setReq)
err = mys.Set(context.TODO(), setReq)
assert.NoError(t, err)
// Insert with an etag should fail on new keys
@ -435,7 +436,7 @@ func TestMySQLIntegration(t *testing.T) {
},
}
err = mys.Set(setReq)
err = mys.Set(context.TODO(), setReq)
assert.ErrorContains(t, err, "possible etag mismatch")
})
@ -474,7 +475,7 @@ func TestMySQLIntegration(t *testing.T) {
})
}
err := mys.Multi(&state.TransactionalStateRequest{
err := mys.Multi(context.TODO(), &state.TransactionalStateRequest{
Operations: operations,
})
assert.Nil(t, err)
@ -510,7 +511,7 @@ func TestMySQLIntegration(t *testing.T) {
})
}
err := mys.Multi(&state.TransactionalStateRequest{
err := mys.Multi(context.TODO(), &state.TransactionalStateRequest{
Operations: operations,
})
assert.Nil(t, err)
@ -537,7 +538,7 @@ func TestMySQLIntegration(t *testing.T) {
})
}
err := mys.Multi(&state.TransactionalStateRequest{
err := mys.Multi(context.TODO(), &state.TransactionalStateRequest{
Operations: operations,
})
assert.Nil(t, err)
@ -562,7 +563,7 @@ func testBulkSetAndBulkDelete(t *testing.T, mys *MySQL) {
},
}
err := mys.BulkSet(setReq)
err := mys.BulkSet(context.TODO(), setReq)
assert.Nil(t, err)
assert.True(t, storeItemExists(t, setReq[0].Key))
assert.True(t, storeItemExists(t, setReq[1].Key))
@ -576,7 +577,7 @@ func testBulkSetAndBulkDelete(t *testing.T, mys *MySQL) {
},
}
err = mys.BulkDelete(deleteReq)
err = mys.BulkDelete(context.TODO(), deleteReq)
assert.Nil(t, err)
assert.False(t, storeItemExists(t, setReq[0].Key))
assert.False(t, storeItemExists(t, setReq[1].Key))
@ -596,7 +597,7 @@ func setItem(t *testing.T, mys *MySQL, key string, value interface{}, eTag *stri
Value: value,
}
err := mys.Set(setReq)
err := mys.Set(context.TODO(), setReq)
assert.Nil(t, err, "Error setting an item")
itemExists := storeItemExists(t, key)
assert.True(t, itemExists, "Item does not exist after being set")
@ -608,7 +609,7 @@ func getItem(t *testing.T, mys *MySQL, key string) (*state.GetResponse, *fakeIte
Options: state.GetStateOption{},
}
response, getErr := mys.Get(getReq)
response, getErr := mys.Get(context.TODO(), getReq)
assert.Nil(t, getErr)
assert.NotNil(t, response)
outputObject := &fakeItem{}
@ -624,7 +625,7 @@ func deleteItem(t *testing.T, mys *MySQL, key string, eTag *string) {
Options: state.DeleteStateOption{},
}
deleteErr := mys.Delete(deleteReq)
deleteErr := mys.Delete(context.TODO(), deleteReq)
assert.Nil(t, deleteErr, "There was an error deleting a record")
assert.False(t, storeItemExists(t, key), "Item still exists after delete")
}

View File

@ -15,6 +15,7 @@ limitations under the License.
package mysql
import (
"context"
"database/sql"
"encoding/base64"
"encoding/json"
@ -161,7 +162,7 @@ func TestExecuteMultiCannotBeginTransaction(t *testing.T) {
m.mock1.ExpectBegin().WillReturnError(fmt.Errorf("beginError"))
// Act
err := m.mySQL.Multi(nil)
err := m.mySQL.Multi(context.TODO(), nil)
// Assert
assert.NotNil(t, err, "no error returned")
@ -180,7 +181,7 @@ func TestMySQLBulkDeleteRollbackDeletes(t *testing.T) {
deletes := []state.DeleteRequest{createDeleteRequest()}
// Act
err := m.mySQL.BulkDelete(deletes)
err := m.mySQL.BulkDelete(context.TODO(), deletes)
// Assert
assert.NotNil(t, err, "no error returned")
@ -199,7 +200,7 @@ func TestMySQLBulkSetRollbackSets(t *testing.T) {
sets := []state.SetRequest{createSetRequest()}
// Act
err := m.mySQL.BulkSet(sets)
err := m.mySQL.BulkSet(context.TODO(), sets)
// Assert
assert.NotNil(t, err, "no error returned")
@ -232,7 +233,7 @@ func TestExecuteMultiCommitSetsAndDeletes(t *testing.T) {
}
// Act
err := m.mySQL.Multi(&request)
err := m.mySQL.Multi(context.TODO(), &request)
// Assert
assert.Nil(t, err, "error returned")
@ -248,7 +249,7 @@ func TestSetHandlesOptionsError(t *testing.T) {
request.Options.Consistency = "Invalid"
// Act
err := m.mySQL.setValue(&request)
err := m.mySQL.setValue(context.TODO(), &request)
// Assert
assert.NotNil(t, err)
@ -263,7 +264,7 @@ func TestSetHandlesNoKey(t *testing.T) {
request.Key = ""
// Act
err := m.mySQL.Set(&request)
err := m.mySQL.Set(context.TODO(), &request)
// Assert
assert.NotNil(t, err)
@ -283,7 +284,7 @@ func TestSetHandlesUpdate(t *testing.T) {
request.ETag = &eTag
// Act
err := m.mySQL.setValue(&request)
err := m.mySQL.setValue(context.TODO(), &request)
// Assert
assert.Nil(t, err)
@ -302,7 +303,7 @@ func TestSetHandlesErr(t *testing.T) {
request.ETag = &eTag
// Act
err := m.mySQL.setValue(&request)
err := m.mySQL.setValue(context.TODO(), &request)
// Assert
assert.NotNil(t, err)
@ -315,7 +316,7 @@ func TestSetHandlesErr(t *testing.T) {
request := createSetRequest()
// Act
err := m.mySQL.setValue(&request)
err := m.mySQL.setValue(context.TODO(), &request)
// Assert
assert.NotNil(t, err)
@ -327,7 +328,7 @@ func TestSetHandlesErr(t *testing.T) {
request := createSetRequest()
// Act
err := m.mySQL.setValue(&request)
err := m.mySQL.setValue(context.TODO(), &request)
// Assert
assert.Nil(t, err)
@ -338,7 +339,7 @@ func TestSetHandlesErr(t *testing.T) {
request := createSetRequest()
// Act
err := m.mySQL.setValue(&request)
err := m.mySQL.setValue(context.TODO(), &request)
// Assert
assert.NotNil(t, err)
@ -352,7 +353,7 @@ func TestSetHandlesErr(t *testing.T) {
request.ETag = &eTag
// Act
err := m.mySQL.setValue(&request)
err := m.mySQL.setValue(context.TODO(), &request)
// Assert
assert.NotNil(t, err)
@ -369,7 +370,7 @@ func TestMySQLDeleteHandlesNoKey(t *testing.T) {
request.Key = ""
// Act
err := m.mySQL.Delete(&request)
err := m.mySQL.Delete(context.TODO(), &request)
// Asset
assert.NotNil(t, err)
@ -388,7 +389,7 @@ func TestDeleteWithETag(t *testing.T) {
request.ETag = &eTag
// Act
err := m.mySQL.deleteValue(&request)
err := m.mySQL.deleteValue(context.TODO(), &request)
// Assert
assert.Nil(t, err)
@ -405,7 +406,7 @@ func TestDeleteWithErr(t *testing.T) {
request := createDeleteRequest()
// Act
err := m.mySQL.deleteValue(&request)
err := m.mySQL.deleteValue(context.TODO(), &request)
// Assert
assert.NotNil(t, err)
@ -420,7 +421,7 @@ func TestDeleteWithErr(t *testing.T) {
request.ETag = &eTag
// Act
err := m.mySQL.deleteValue(&request)
err := m.mySQL.deleteValue(context.TODO(), &request)
// Assert
assert.NotNil(t, err)
@ -441,7 +442,7 @@ func TestGetHandlesNoRows(t *testing.T) {
}
// Act
response, err := m.mySQL.Get(request)
response, err := m.mySQL.Get(context.TODO(), request)
// Assert
assert.Nil(t, err, "returned error")
@ -458,7 +459,7 @@ func TestGetHandlesNoKey(t *testing.T) {
}
// Act
response, err := m.mySQL.Get(request)
response, err := m.mySQL.Get(context.TODO(), request)
// Assert
assert.NotNil(t, err, "returned error")
@ -478,7 +479,7 @@ func TestGetHandlesGenericError(t *testing.T) {
}
// Act
response, err := m.mySQL.Get(request)
response, err := m.mySQL.Get(context.TODO(), request)
// Assert
assert.NotNil(t, err)
@ -499,7 +500,7 @@ func TestGetSucceeds(t *testing.T) {
}
// Act
response, err := m.mySQL.Get(request)
response, err := m.mySQL.Get(context.TODO(), request)
// Assert
assert.Nil(t, err)
@ -517,7 +518,7 @@ func TestGetSucceeds(t *testing.T) {
}
// Act
response, err := m.mySQL.Get(request)
response, err := m.mySQL.Get(context.TODO(), request)
// Assert
assert.Nil(t, err)
@ -711,7 +712,7 @@ func TestBulkGetReturnsNil(t *testing.T) {
m, _ := mockDatabase(t)
// Act
supported, response, err := m.mySQL.BulkGet(nil)
supported, response, err := m.mySQL.BulkGet(context.TODO(), nil)
// Assert
assert.Nil(t, err, `returned err`)
@ -730,7 +731,7 @@ func TestMultiWithNoRequestsDoesNothing(t *testing.T) {
m.mock1.ExpectCommit()
// Act
err := m.mySQL.Multi(&state.TransactionalStateRequest{
err := m.mySQL.Multi(context.TODO(), &state.TransactionalStateRequest{
Operations: ops,
})
@ -750,7 +751,7 @@ func TestInvalidMultiAction(t *testing.T) {
})
// Act
err := m.mySQL.Multi(&state.TransactionalStateRequest{
err := m.mySQL.Multi(context.TODO(), &state.TransactionalStateRequest{
Operations: ops,
})
@ -789,7 +790,7 @@ func TestValidSetRequest(t *testing.T) {
m.mock1.ExpectCommit()
// Act
err := m.mySQL.Multi(&state.TransactionalStateRequest{
err := m.mySQL.Multi(context.TODO(), &state.TransactionalStateRequest{
Operations: ops,
})
@ -810,7 +811,7 @@ func TestInvalidMultiSetRequest(t *testing.T) {
})
// Act
err := m.mySQL.Multi(&state.TransactionalStateRequest{
err := m.mySQL.Multi(context.TODO(), &state.TransactionalStateRequest{
Operations: ops,
})
@ -834,7 +835,7 @@ func TestInvalidMultiSetRequestNoKey(t *testing.T) {
})
// Act
err := m.mySQL.Multi(&state.TransactionalStateRequest{
err := m.mySQL.Multi(context.TODO(), &state.TransactionalStateRequest{
Operations: ops,
})
@ -858,7 +859,7 @@ func TestValidMultiDeleteRequest(t *testing.T) {
})
// Act
err := m.mySQL.Multi(&state.TransactionalStateRequest{
err := m.mySQL.Multi(context.TODO(), &state.TransactionalStateRequest{
Operations: ops,
})
@ -879,7 +880,7 @@ func TestInvalidMultiDeleteRequest(t *testing.T) {
})
// Act
err := m.mySQL.Multi(&state.TransactionalStateRequest{
err := m.mySQL.Multi(context.TODO(), &state.TransactionalStateRequest{
Operations: ops,
})
@ -902,7 +903,7 @@ func TestInvalidMultiDeleteRequestNoKey(t *testing.T) {
})
// Act
err := m.mySQL.Multi(&state.TransactionalStateRequest{
err := m.mySQL.Multi(context.TODO(), &state.TransactionalStateRequest{
Operations: ops,
})
@ -941,7 +942,7 @@ func TestMultiOperationOrder(t *testing.T) {
m.mock1.ExpectCommit()
// Act
err := m.mySQL.Multi(&state.TransactionalStateRequest{
err := m.mySQL.Multi(context.TODO(), &state.TransactionalStateRequest{
Operations: ops,
})

View File

@ -131,15 +131,15 @@ func (r *StateStore) Features() []state.Feature {
return r.features
}
func (r *StateStore) Delete(req *state.DeleteRequest) error {
func (r *StateStore) Delete(ctx context.Context, req *state.DeleteRequest) error {
r.logger.Debugf("Delete entry from OCI Object Storage State Store with key ", req.Key)
err := r.deleteDocument(req)
err := r.deleteDocument(ctx, req)
return err
}
func (r *StateStore) Get(req *state.GetRequest) (*state.GetResponse, error) {
func (r *StateStore) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) {
r.logger.Debugf("Get from OCI Object Storage State Store with key ", req.Key)
content, etag, err := r.readDocument((req))
content, etag, err := r.readDocument(ctx, req)
if err != nil {
r.logger.Debugf("error %s", err)
if err.Error() == "ObjectNotFound" {
@ -155,9 +155,9 @@ func (r *StateStore) Get(req *state.GetRequest) (*state.GetResponse, error) {
}, err
}
func (r *StateStore) Set(req *state.SetRequest) error {
func (r *StateStore) Set(ctx context.Context, req *state.SetRequest) error {
r.logger.Debugf("saving %s to OCI Object Storage State Store", req.Key)
return r.writeDocument(req)
return r.writeDocument(ctx, req)
}
func (r *StateStore) Ping() error {
@ -267,7 +267,7 @@ func getIdentityAuthenticationDetails(metadata map[string]string, meta *Metadata
}
// functions that bridge from the Dapr State API to the OCI ObjectStorage Client.
func (r *StateStore) writeDocument(req *state.SetRequest) error {
func (r *StateStore) writeDocument(ctx context.Context, req *state.SetRequest) error {
if len(req.Key) == 0 || req.Key == "" {
return fmt.Errorf("key for value to set was missing from request")
}
@ -286,7 +286,6 @@ func (r *StateStore) writeDocument(req *state.SetRequest) error {
objectName := getFileName(req.Key)
content := r.marshal(req)
objectLength := int64(len(content))
ctx := context.Background()
etag := req.ETag
if req.Options.Concurrency != state.FirstWrite {
etag = nil
@ -315,12 +314,11 @@ func (r *StateStore) convertTTLtoExpiryTime(req *state.SetRequest, metadata map[
return nil
}
func (r *StateStore) readDocument(req *state.GetRequest) ([]byte, *string, error) {
func (r *StateStore) readDocument(ctx context.Context, req *state.GetRequest) ([]byte, *string, error) {
if len(req.Key) == 0 || req.Key == "" {
return nil, nil, fmt.Errorf("key for value to get was missing from request")
}
objectName := getFileName(req.Key)
ctx := context.Background()
content, etag, meta, err := r.client.getObject(ctx, objectName)
if err != nil {
r.logger.Debugf("download file %s, err %s", req.Key, err)
@ -348,13 +346,12 @@ func (r *StateStore) pingBucket() error {
return nil
}
func (r *StateStore) deleteDocument(req *state.DeleteRequest) error {
func (r *StateStore) deleteDocument(ctx context.Context, req *state.DeleteRequest) error {
if len(req.Key) == 0 || req.Key == "" {
return fmt.Errorf("key for value to delete was missing from request")
}
objectName := getFileName(req.Key)
ctx := context.Background()
etag := req.ETag
if req.Options.Concurrency != state.FirstWrite {
etag = nil

View File

@ -4,6 +4,7 @@ package objectstorage
// go test -v github.com/dapr/components-contrib/state/oci/objectstorage.
import (
"context"
"fmt"
"os"
"testing"
@ -88,16 +89,16 @@ func testGet(t *testing.T, ociProperties map[string]string) {
t.Run("Get an non-existing key", func(t *testing.T) {
err := statestore.Init(meta)
assert.Nil(t, err)
getResponse, err := statestore.Get(&state.GetRequest{Key: "xyzq"})
getResponse, err := statestore.Get(context.TODO(), &state.GetRequest{Key: "xyzq"})
assert.Equal(t, &state.GetResponse{}, getResponse, "Response must be empty")
assert.NoError(t, err, "Non-existing key must not be treated as error")
})
t.Run("Get an existing key", func(t *testing.T) {
err := statestore.Init(meta)
assert.Nil(t, err)
err = statestore.Set(&state.SetRequest{Key: "test-key", Value: []byte("test-value")})
err = statestore.Set(context.TODO(), &state.SetRequest{Key: "test-key", Value: []byte("test-value")})
assert.Nil(t, err)
getResponse, err := statestore.Get(&state.GetRequest{Key: "test-key"})
getResponse, err := statestore.Get(context.TODO(), &state.GetRequest{Key: "test-key"})
assert.Nil(t, err)
assert.Equal(t, "test-value", string(getResponse.Data), "Value retrieved should be equal to value set")
assert.NotNil(t, *getResponse.ETag, "ETag should be set")
@ -105,9 +106,9 @@ func testGet(t *testing.T, ociProperties map[string]string) {
t.Run("Get an existing composed key", func(t *testing.T) {
err := statestore.Init(meta)
assert.Nil(t, err)
err = statestore.Set(&state.SetRequest{Key: "test-app||test-key", Value: []byte("test-value")})
err = statestore.Set(context.TODO(), &state.SetRequest{Key: "test-app||test-key", Value: []byte("test-value")})
assert.Nil(t, err)
getResponse, err := statestore.Get(&state.GetRequest{Key: "test-app||test-key"})
getResponse, err := statestore.Get(context.TODO(), &state.GetRequest{Key: "test-app||test-key"})
assert.Nil(t, err)
assert.Equal(t, "test-value", string(getResponse.Data), "Value retrieved should be equal to value set")
})
@ -115,11 +116,11 @@ func testGet(t *testing.T, ociProperties map[string]string) {
testKey := "unexpired-ttl-test-key"
err := statestore.Init(meta)
assert.Nil(t, err)
err = statestore.Set(&state.SetRequest{Key: testKey, Value: []byte("test-value"), Metadata: (map[string]string{
err = statestore.Set(context.TODO(), &state.SetRequest{Key: testKey, Value: []byte("test-value"), Metadata: (map[string]string{
"ttlInSeconds": "100",
})})
assert.Nil(t, err)
getResponse, err := statestore.Get(&state.GetRequest{Key: testKey})
getResponse, err := statestore.Get(context.TODO(), &state.GetRequest{Key: testKey})
assert.Nil(t, err)
assert.Equal(t, "test-value", string(getResponse.Data), "Value retrieved should be equal to value set despite TTL setting")
})
@ -127,23 +128,23 @@ func testGet(t *testing.T, ociProperties map[string]string) {
testKey := "never-expiring-ttl-test-key"
err := statestore.Init(meta)
assert.Nil(t, err)
err = statestore.Set(&state.SetRequest{Key: testKey, Value: []byte("test-value"), Metadata: (map[string]string{
err = statestore.Set(context.TODO(), &state.SetRequest{Key: testKey, Value: []byte("test-value"), Metadata: (map[string]string{
"ttlInSeconds": "-1",
})})
assert.Nil(t, err)
getResponse, err := statestore.Get(&state.GetRequest{Key: testKey})
getResponse, err := statestore.Get(context.TODO(), &state.GetRequest{Key: testKey})
assert.Nil(t, err)
assert.Equal(t, "test-value", string(getResponse.Data), "Value retrieved should be equal (TTL setting of -1 means never expire)")
})
t.Run("Get an expired (TTL in the past) state element", func(t *testing.T) {
err := statestore.Init(meta)
assert.Nil(t, err)
err = statestore.Set(&state.SetRequest{Key: "ttl-test-key", Value: []byte("test-value"), Metadata: (map[string]string{
err = statestore.Set(context.TODO(), &state.SetRequest{Key: "ttl-test-key", Value: []byte("test-value"), Metadata: (map[string]string{
"ttlInSeconds": "1",
})})
assert.Nil(t, err)
time.Sleep(time.Second * 2)
getResponse, err := statestore.Get(&state.GetRequest{Key: "ttl-test-key"})
getResponse, err := statestore.Get(context.TODO(), &state.GetRequest{Key: "ttl-test-key"})
assert.Equal(t, &state.GetResponse{}, getResponse, "Response must be empty")
assert.NoError(t, err, "Expired element must not be treated as error")
})
@ -156,7 +157,7 @@ func testSet(t *testing.T, ociProperties map[string]string) {
t.Run("Set without a key", func(t *testing.T) {
err := statestore.Init(meta)
assert.Nil(t, err)
err = statestore.Set(&state.SetRequest{Value: []byte("test-value")})
err = statestore.Set(context.TODO(), &state.SetRequest{Value: []byte("test-value")})
assert.Equal(t, err, fmt.Errorf("key for value to set was missing from request"), "Lacking Key results in error")
})
t.Run("Regular Set Operation", func(t *testing.T) {
@ -164,9 +165,9 @@ func testSet(t *testing.T, ociProperties map[string]string) {
err := statestore.Init(meta)
assert.Nil(t, err)
err = statestore.Set(&state.SetRequest{Key: testKey, Value: []byte("test-value")})
err = statestore.Set(context.TODO(), &state.SetRequest{Key: testKey, Value: []byte("test-value")})
assert.Nil(t, err, "Setting a value with a proper key should be errorfree")
getResponse, err := statestore.Get(&state.GetRequest{Key: testKey})
getResponse, err := statestore.Get(context.TODO(), &state.GetRequest{Key: testKey})
assert.Nil(t, err)
assert.Equal(t, "test-value", string(getResponse.Data), "Value retrieved should be equal to value set")
assert.NotNil(t, *getResponse.ETag, "ETag should be set")
@ -176,20 +177,20 @@ func testSet(t *testing.T, ociProperties map[string]string) {
err := statestore.Init(meta)
assert.Nil(t, err)
err = statestore.Set(&state.SetRequest{Key: testKey, Value: []byte("test-value")})
err = statestore.Set(context.TODO(), &state.SetRequest{Key: testKey, Value: []byte("test-value")})
assert.Nil(t, err, "Setting a value with a proper composite key should be errorfree")
getResponse, err := statestore.Get(&state.GetRequest{Key: testKey})
getResponse, err := statestore.Get(context.TODO(), &state.GetRequest{Key: testKey})
assert.Nil(t, err)
assert.Equal(t, "test-value", string(getResponse.Data), "Value retrieved should be equal to value set")
assert.NotNil(t, *getResponse.ETag, "ETag should be set")
})
t.Run("Regular Set Operation with TTL", func(t *testing.T) {
testKey := "test-key-with-ttl"
err := statestore.Set(&state.SetRequest{Key: testKey, Value: []byte("test-value"), Metadata: (map[string]string{
err := statestore.Set(context.TODO(), &state.SetRequest{Key: testKey, Value: []byte("test-value"), Metadata: (map[string]string{
"ttlInSeconds": "500",
})})
assert.Nil(t, err, "Setting a value with a proper key and a correct TTL value should be errorfree")
err = statestore.Set(&state.SetRequest{Key: testKey, Value: []byte("test-value"), Metadata: (map[string]string{
err = statestore.Set(context.TODO(), &state.SetRequest{Key: testKey, Value: []byte("test-value"), Metadata: (map[string]string{
"ttlInSeconds": "XXX",
})})
assert.NotNil(t, err, "Setting a value with a proper key and a incorrect TTL value should be produce an error")
@ -200,25 +201,25 @@ func testSet(t *testing.T, ociProperties map[string]string) {
err := statestore.Init(meta)
assert.Nil(t, err)
err = statestore.Set(&state.SetRequest{Key: testKey, Value: []byte("test-value")})
err = statestore.Set(context.TODO(), &state.SetRequest{Key: testKey, Value: []byte("test-value")})
assert.Nil(t, err, "Setting a value with a proper key should be errorfree")
getResponse, _ := statestore.Get(&state.GetRequest{Key: testKey})
getResponse, _ := statestore.Get(context.TODO(), &state.GetRequest{Key: testKey})
etag := getResponse.ETag
err = statestore.Set(&state.SetRequest{Key: testKey, Value: []byte("overwritten-value"), ETag: etag, Options: state.SetStateOption{
err = statestore.Set(context.TODO(), &state.SetRequest{Key: testKey, Value: []byte("overwritten-value"), ETag: etag, Options: state.SetStateOption{
Concurrency: state.FirstWrite,
}})
assert.Nil(t, err, "Updating value with proper etag should go fine")
err = statestore.Set(&state.SetRequest{Key: testKey, Value: []byte("more-overwritten-value"), ETag: etag, Options: state.SetStateOption{
err = statestore.Set(context.TODO(), &state.SetRequest{Key: testKey, Value: []byte("more-overwritten-value"), ETag: etag, Options: state.SetStateOption{
Concurrency: state.FirstWrite,
}})
assert.NotNil(t, err, "Updating value with the old etag should be refused")
// retrieve the latest etag - assigned by the previous set operation.
getResponse, _ = statestore.Get(&state.GetRequest{Key: testKey})
getResponse, _ = statestore.Get(context.TODO(), &state.GetRequest{Key: testKey})
assert.NotNil(t, *getResponse.ETag, "ETag should be set")
etag = getResponse.ETag
err = statestore.Set(&state.SetRequest{Key: testKey, Value: []byte("more-overwritten-value"), ETag: etag, Options: state.SetStateOption{
err = statestore.Set(context.TODO(), &state.SetRequest{Key: testKey, Value: []byte("more-overwritten-value"), ETag: etag, Options: state.SetStateOption{
Concurrency: state.FirstWrite,
}})
assert.Nil(t, err, "Updating value with the latest etag should be accepted")
@ -232,7 +233,7 @@ func testDelete(t *testing.T, ociProperties map[string]string) {
t.Run("Delete without a key", func(t *testing.T) {
err := s.Init(m)
assert.Nil(t, err)
err = s.Delete(&state.DeleteRequest{})
err = s.Delete(context.TODO(), &state.DeleteRequest{})
assert.Equal(t, err, fmt.Errorf("key for value to delete was missing from request"), "Lacking Key results in error")
})
t.Run("Regular Delete Operation", func(t *testing.T) {
@ -240,9 +241,9 @@ func testDelete(t *testing.T, ociProperties map[string]string) {
err := s.Init(m)
assert.Nil(t, err)
err = s.Set(&state.SetRequest{Key: testKey, Value: []byte("test-value")})
err = s.Set(context.TODO(), &state.SetRequest{Key: testKey, Value: []byte("test-value")})
assert.Nil(t, err, "Setting a value with a proper key should be errorfree")
err = s.Delete(&state.DeleteRequest{Key: testKey})
err = s.Delete(context.TODO(), &state.DeleteRequest{Key: testKey})
assert.Nil(t, err, "Deleting an existing value with a proper key should be errorfree")
})
t.Run("Regular Delete Operation for composite key", func(t *testing.T) {
@ -250,13 +251,13 @@ func testDelete(t *testing.T, ociProperties map[string]string) {
err := s.Init(m)
assert.Nil(t, err)
err = s.Set(&state.SetRequest{Key: testKey, Value: []byte("test-value")})
err = s.Set(context.TODO(), &state.SetRequest{Key: testKey, Value: []byte("test-value")})
assert.Nil(t, err, "Setting a value with a proper composite key should be errorfree")
err = s.Delete(&state.DeleteRequest{Key: testKey})
err = s.Delete(context.TODO(), &state.DeleteRequest{Key: testKey})
assert.Nil(t, err, "Deleting an existing value with a proper composite key should be errorfree")
})
t.Run("Delete with an unknown key", func(t *testing.T) {
err := s.Delete(&state.DeleteRequest{Key: "unknownKey"})
err := s.Delete(context.TODO(), &state.DeleteRequest{Key: "unknownKey"})
assert.Contains(t, err.Error(), "404", "Unknown Key results in error: http status code 404, object not found")
})
@ -265,18 +266,18 @@ func testDelete(t *testing.T, ociProperties map[string]string) {
err := s.Init(m)
assert.Nil(t, err)
// create document.
err = s.Set(&state.SetRequest{Key: testKey, Value: []byte("test-value")})
err = s.Set(context.TODO(), &state.SetRequest{Key: testKey, Value: []byte("test-value")})
assert.Nil(t, err, "Setting a value with a proper key should be errorfree")
getResponse, _ := s.Get(&state.GetRequest{Key: testKey})
getResponse, _ := s.Get(context.TODO(), &state.GetRequest{Key: testKey})
etag := getResponse.ETag
incorrectETag := "someRandomETag"
err = s.Delete(&state.DeleteRequest{Key: testKey, ETag: &incorrectETag, Options: state.DeleteStateOption{
err = s.Delete(context.TODO(), &state.DeleteRequest{Key: testKey, ETag: &incorrectETag, Options: state.DeleteStateOption{
Concurrency: state.FirstWrite,
}})
assert.NotNil(t, err, "Deleting value with an incorrect etag should be prevented")
err = s.Delete(&state.DeleteRequest{Key: testKey, ETag: etag, Options: state.DeleteStateOption{
err = s.Delete(context.TODO(), &state.DeleteRequest{Key: testKey, ETag: etag, Options: state.DeleteStateOption{
Concurrency: state.FirstWrite,
}})
assert.Nil(t, err, "Deleting value with proper etag should go fine")

View File

@ -236,24 +236,24 @@ func TestGetWithMockClient(t *testing.T) {
s.client = mockClient
t.Parallel()
t.Run("Test regular Get", func(t *testing.T) {
getResponse, err := s.Get(&state.GetRequest{Key: "test-key"})
getResponse, err := s.Get(context.TODO(), &state.GetRequest{Key: "test-key"})
assert.True(t, mockClient.getIsCalled, "function Get should be invoked on the mockClient")
assert.Equal(t, "Hello World", string(getResponse.Data), "Value retrieved should be equal to value set")
assert.NotNil(t, *getResponse.ETag, "ETag should be set")
assert.Nil(t, err)
})
t.Run("Test Get with composite key", func(t *testing.T) {
getResponse, err := s.Get(&state.GetRequest{Key: "test-app||test-key"})
getResponse, err := s.Get(context.TODO(), &state.GetRequest{Key: "test-app||test-key"})
assert.Equal(t, "Hello Continent", string(getResponse.Data), "Value retrieved should be equal to value set")
assert.Nil(t, err)
})
t.Run("Test Get with an unknown key", func(t *testing.T) {
getResponse, err := s.Get(&state.GetRequest{Key: "unknownKey"})
getResponse, err := s.Get(context.TODO(), &state.GetRequest{Key: "unknownKey"})
assert.Nil(t, getResponse.Data, "No value should be retrieved for an unknown key")
assert.Nil(t, err, "404", "Not finding an object because of unknown key should not result in an error")
})
t.Run("Test expired element (because of TTL) ", func(t *testing.T) {
getResponse, err := s.Get(&state.GetRequest{Key: "test-expired-ttl-key"})
getResponse, err := s.Get(context.TODO(), &state.GetRequest{Key: "test-expired-ttl-key"})
assert.Nil(t, getResponse.Data, "No value should be retrieved for an expired state element")
assert.Nil(t, err, "Not returning an object because of expiration should not result in an error")
})
@ -289,28 +289,28 @@ func TestSetWithMockClient(t *testing.T) {
mockClient := &mockedObjectStoreClient{}
statestore.client = mockClient
t.Run("Set without a key", func(t *testing.T) {
err := statestore.Set(&state.SetRequest{Value: []byte("test-value")})
err := statestore.Set(context.TODO(), &state.SetRequest{Value: []byte("test-value")})
assert.Equal(t, err, fmt.Errorf("key for value to set was missing from request"), "Lacking Key results in error")
})
t.Run("Regular Set Operation", func(t *testing.T) {
testKey := "test-key"
err := statestore.Set(&state.SetRequest{Key: testKey, Value: []byte("test-value")})
err := statestore.Set(context.TODO(), &state.SetRequest{Key: testKey, Value: []byte("test-value")})
assert.Nil(t, err, "Setting a value with a proper key should be errorfree")
assert.True(t, mockClient.putIsCalled, "function put should be invoked on the mockClient")
})
t.Run("Regular Set Operation with TTL", func(t *testing.T) {
testKey := "test-key"
err := statestore.Set(&state.SetRequest{Key: testKey, Value: []byte("test-value"), Metadata: (map[string]string{
err := statestore.Set(context.TODO(), &state.SetRequest{Key: testKey, Value: []byte("test-value"), Metadata: (map[string]string{
"ttlInSeconds": "5",
})})
assert.Nil(t, err, "Setting a value with a proper key and a correct TTL value should be errorfree")
err = statestore.Set(&state.SetRequest{Key: testKey, Value: []byte("test-value"), Metadata: (map[string]string{
err = statestore.Set(context.TODO(), &state.SetRequest{Key: testKey, Value: []byte("test-value"), Metadata: (map[string]string{
"ttlInSeconds": "XXX",
})})
assert.NotNil(t, err, "Setting a value with a proper key and a incorrect TTL value should be produce an error")
err = statestore.Set(&state.SetRequest{Key: testKey, Value: []byte("test-value"), Metadata: (map[string]string{
err = statestore.Set(context.TODO(), &state.SetRequest{Key: testKey, Value: []byte("test-value"), Metadata: (map[string]string{
"ttlInSeconds": "1",
})})
assert.Nil(t, err, "Setting a value with a proper key and a correct TTL value should be errorfree")
@ -320,22 +320,22 @@ func TestSetWithMockClient(t *testing.T) {
incorrectETag := "notTheCorrectETag"
etag := "correctETag"
err := statestore.Set(&state.SetRequest{Key: testKey, Value: []byte("overwritten-value"), ETag: &incorrectETag, Options: state.SetStateOption{
err := statestore.Set(context.TODO(), &state.SetRequest{Key: testKey, Value: []byte("overwritten-value"), ETag: &incorrectETag, Options: state.SetStateOption{
Concurrency: state.FirstWrite,
}})
assert.NotNil(t, err, "Updating value with wrong etag should fail")
err = statestore.Set(&state.SetRequest{Key: testKey, Value: []byte("overwritten-value"), ETag: nil, Options: state.SetStateOption{
err = statestore.Set(context.TODO(), &state.SetRequest{Key: testKey, Value: []byte("overwritten-value"), ETag: nil, Options: state.SetStateOption{
Concurrency: state.FirstWrite,
}})
assert.NotNil(t, err, "Asking for FirstWrite concurrency policy without ETag should fail")
err = statestore.Set(&state.SetRequest{Key: testKey, Value: []byte("overwritten-value"), ETag: &etag, Options: state.SetStateOption{
err = statestore.Set(context.TODO(), &state.SetRequest{Key: testKey, Value: []byte("overwritten-value"), ETag: &etag, Options: state.SetStateOption{
Concurrency: state.FirstWrite,
}})
assert.Nil(t, err, "Updating value with proper etag should go fine")
err = statestore.Set(&state.SetRequest{Key: testKey, Value: []byte("overwritten-value"), ETag: nil, Options: state.SetStateOption{
err = statestore.Set(context.TODO(), &state.SetRequest{Key: testKey, Value: []byte("overwritten-value"), ETag: nil, Options: state.SetStateOption{
Concurrency: state.FirstWrite,
}})
assert.NotNil(t, err, "Updating value with concurrency policy at FirstWrite should fail when ETag is missing")
@ -348,34 +348,34 @@ func TestDeleteWithMockClient(t *testing.T) {
mockClient := &mockedObjectStoreClient{}
s.client = mockClient
t.Run("Delete without a key", func(t *testing.T) {
err := s.Delete(&state.DeleteRequest{})
err := s.Delete(context.TODO(), &state.DeleteRequest{})
assert.Equal(t, err, fmt.Errorf("key for value to delete was missing from request"), "Lacking Key results in error")
})
t.Run("Delete with an unknown key", func(t *testing.T) {
err := s.Delete(&state.DeleteRequest{Key: "unknownKey"})
err := s.Delete(context.TODO(), &state.DeleteRequest{Key: "unknownKey"})
assert.Contains(t, err.Error(), "404", "Unknown Key results in error: http status code 404, object not found")
})
t.Run("Regular Delete Operation", func(t *testing.T) {
testKey := "test-key"
err := s.Delete(&state.DeleteRequest{Key: testKey})
err := s.Delete(context.TODO(), &state.DeleteRequest{Key: testKey})
assert.Nil(t, err, "Deleting an existing value with a proper key should be errorfree")
assert.True(t, mockClient.deleteIsCalled, "function delete should be invoked on the mockClient")
})
t.Run("Testing Delete & Concurrency (ETags)", func(t *testing.T) {
testKey := "etag-test-delete-key"
incorrectETag := "notTheCorrectETag"
err := s.Delete(&state.DeleteRequest{Key: testKey, ETag: &incorrectETag, Options: state.DeleteStateOption{
err := s.Delete(context.TODO(), &state.DeleteRequest{Key: testKey, ETag: &incorrectETag, Options: state.DeleteStateOption{
Concurrency: state.FirstWrite,
}})
assert.NotNil(t, err, "Deleting value with an incorrect etag should be prevented")
etag := "correctETag"
err = s.Delete(&state.DeleteRequest{Key: testKey, ETag: &etag, Options: state.DeleteStateOption{
err = s.Delete(context.TODO(), &state.DeleteRequest{Key: testKey, ETag: &etag, Options: state.DeleteStateOption{
Concurrency: state.FirstWrite,
}})
assert.Nil(t, err, "Deleting value with proper etag should go fine")
err = s.Delete(&state.DeleteRequest{Key: testKey, ETag: nil, Options: state.DeleteStateOption{
err = s.Delete(context.TODO(), &state.DeleteRequest{Key: testKey, ETag: nil, Options: state.DeleteStateOption{
Concurrency: state.FirstWrite,
}})
assert.NotNil(t, err, "Asking for FirstWrite concurrency policy without ETag should fail")

View File

@ -14,6 +14,8 @@ limitations under the License.
package oracledatabase
import (
"context"
"github.com/dapr/components-contrib/state"
)
@ -21,9 +23,9 @@ import (
type dbAccess interface {
Init(metadata state.Metadata) error
Ping() error
Set(req *state.SetRequest) error
Get(req *state.GetRequest) (*state.GetResponse, error)
Delete(req *state.DeleteRequest) error
ExecuteMulti(sets []state.SetRequest, deletes []state.DeleteRequest) error
Set(ctx context.Context, req *state.SetRequest) error
Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error)
Delete(ctx context.Context, req *state.DeleteRequest) error
ExecuteMulti(ctx context.Context, sets []state.SetRequest, deletes []state.DeleteRequest) error
Close() error // io.Closer.
}

View File

@ -14,6 +14,7 @@ limitations under the License.
package oracledatabase
import (
"context"
"fmt"
"github.com/dapr/components-contrib/state"
@ -59,38 +60,38 @@ func (o *OracleDatabase) Features() []state.Feature {
}
// Delete removes an entity from the store.
func (o *OracleDatabase) Delete(req *state.DeleteRequest) error {
return o.dbaccess.Delete(req)
func (o *OracleDatabase) Delete(ctx context.Context, req *state.DeleteRequest) error {
return o.dbaccess.Delete(ctx, req)
}
// BulkDelete removes multiple entries from the store.
func (o *OracleDatabase) BulkDelete(req []state.DeleteRequest) error {
return o.dbaccess.ExecuteMulti(nil, req)
func (o *OracleDatabase) BulkDelete(ctx context.Context, req []state.DeleteRequest) error {
return o.dbaccess.ExecuteMulti(ctx, nil, req)
}
// Get returns an entity from store.
func (o *OracleDatabase) Get(req *state.GetRequest) (*state.GetResponse, error) {
return o.dbaccess.Get(req)
func (o *OracleDatabase) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) {
return o.dbaccess.Get(ctx, req)
}
// BulkGet performs a bulks get operations.
func (o *OracleDatabase) BulkGet(req []state.GetRequest) (bool, []state.BulkGetResponse, error) {
func (o *OracleDatabase) BulkGet(ctx context.Context, req []state.GetRequest) (bool, []state.BulkGetResponse, error) {
// TODO: replace with ExecuteMulti for performance.
return false, nil, nil
}
// Set adds/updates an entity on store.
func (o *OracleDatabase) Set(req *state.SetRequest) error {
return o.dbaccess.Set(req)
func (o *OracleDatabase) Set(ctx context.Context, req *state.SetRequest) error {
return o.dbaccess.Set(ctx, req)
}
// BulkSet adds/updates multiple entities on store.
func (o *OracleDatabase) BulkSet(req []state.SetRequest) error {
return o.dbaccess.ExecuteMulti(req, nil)
func (o *OracleDatabase) BulkSet(ctx context.Context, req []state.SetRequest) error {
return o.dbaccess.ExecuteMulti(ctx, req, nil)
}
// Multi handles multiple transactions. Implements TransactionalStore.
func (o *OracleDatabase) Multi(request *state.TransactionalStateRequest) error {
func (o *OracleDatabase) Multi(ctx context.Context, request *state.TransactionalStateRequest) error {
var deletes []state.DeleteRequest
var sets []state.SetRequest
for _, req := range request.Operations {
@ -115,7 +116,7 @@ func (o *OracleDatabase) Multi(request *state.TransactionalStateRequest) error {
}
if len(sets) > 0 || len(deletes) > 0 {
return o.dbaccess.ExecuteMulti(sets, deletes)
return o.dbaccess.ExecuteMulti(ctx, sets, deletes)
}
return nil

View File

@ -15,6 +15,7 @@ limitations under the License.
package oracledatabase
import (
"context"
"database/sql"
"encoding/json"
"fmt"
@ -223,7 +224,7 @@ func deleteItemThatDoesNotExist(t *testing.T, ods *OracleDatabase) {
deleteReq := &state.DeleteRequest{
Key: randomKey(),
}
err := ods.Delete(deleteReq)
err := ods.Delete(context.TODO(), deleteReq)
assert.Nil(t, err)
}
@ -242,7 +243,7 @@ func multiWithSetOnly(t *testing.T, ods *OracleDatabase) {
})
}
err := ods.Multi(&state.TransactionalStateRequest{
err := ods.Multi(context.TODO(), &state.TransactionalStateRequest{
Operations: operations,
})
assert.Nil(t, err)
@ -272,7 +273,7 @@ func multiWithDeleteOnly(t *testing.T, ods *OracleDatabase) {
})
}
err := ods.Multi(&state.TransactionalStateRequest{
err := ods.Multi(context.TODO(), &state.TransactionalStateRequest{
Operations: operations,
})
assert.Nil(t, err)
@ -315,7 +316,7 @@ func multiWithDeleteAndSet(t *testing.T, ods *OracleDatabase) {
})
}
err := ods.Multi(&state.TransactionalStateRequest{
err := ods.Multi(context.TODO(), &state.TransactionalStateRequest{
Operations: operations,
})
assert.Nil(t, err)
@ -345,7 +346,7 @@ func deleteWithInvalidEtagFails(t *testing.T, ods *OracleDatabase) {
Concurrency: state.FirstWrite,
},
}
err := ods.Delete(deleteReq)
err := ods.Delete(context.TODO(), deleteReq)
assert.NotNil(t, err, "Deleting an item with the wrong etag while enforcing FirstWrite policy should fail")
}
@ -353,7 +354,7 @@ func deleteWithNoKeyFails(t *testing.T, ods *OracleDatabase) {
deleteReq := &state.DeleteRequest{
Key: "",
}
err := ods.Delete(deleteReq)
err := ods.Delete(context.TODO(), deleteReq)
assert.NotNil(t, err)
}
@ -371,7 +372,7 @@ func newItemWithEtagFails(t *testing.T, ods *OracleDatabase) {
},
}
err := ods.Set(setReq)
err := ods.Set(context.TODO(), setReq)
assert.NotNil(t, err)
}
@ -401,7 +402,7 @@ func updateWithOldEtagFails(t *testing.T, ods *OracleDatabase) {
Concurrency: state.FirstWrite,
},
}
err := ods.Set(setReq)
err := ods.Set(context.TODO(), setReq)
assert.NotNil(t, err)
}
@ -423,7 +424,7 @@ func updateAndDeleteWithEtagSucceeds(t *testing.T, ods *OracleDatabase) {
Concurrency: state.FirstWrite,
},
}
err := ods.Set(setReq)
err := ods.Set(context.TODO(), setReq)
assert.Nil(t, err, "Setting the item should be successful")
updateResponse, updatedItem := getItem(t, ods, key)
assert.Equal(t, value, updatedItem)
@ -439,7 +440,7 @@ func updateAndDeleteWithEtagSucceeds(t *testing.T, ods *OracleDatabase) {
Concurrency: state.FirstWrite,
},
}
err = ods.Delete(deleteReq)
err = ods.Delete(context.TODO(), deleteReq)
assert.Nil(t, err, "Deleting an item with the right etag while enforcing FirstWrite policy should succeed")
// Item is not in the data store.
@ -465,7 +466,7 @@ func updateAndDeleteWithWrongEtagAndNoFirstWriteSucceeds(t *testing.T, ods *Orac
Concurrency: state.LastWrite,
},
}
err := ods.Set(setReq)
err := ods.Set(context.TODO(), setReq)
assert.Nil(t, err, "Setting the item should be successful")
_, updatedItem := getItem(t, ods, key)
assert.Equal(t, value, updatedItem)
@ -478,7 +479,7 @@ func updateAndDeleteWithWrongEtagAndNoFirstWriteSucceeds(t *testing.T, ods *Orac
Concurrency: state.LastWrite,
},
}
err = ods.Delete(deleteReq)
err = ods.Delete(context.TODO(), deleteReq)
assert.Nil(t, err, "Deleting an item with the wrong etag but not enforcing FirstWrite policy should succeed")
// Item is not in the data store.
@ -500,7 +501,7 @@ func getItemWithNoKey(t *testing.T, ods *OracleDatabase) {
Key: "",
}
response, getErr := ods.Get(getReq)
response, getErr := ods.Get(context.TODO(), getReq)
assert.NotNil(t, getErr)
assert.Nil(t, response)
}
@ -548,7 +549,7 @@ func setTTLUpdatesExpiry(t *testing.T, ods *OracleDatabase) {
},
}
err := ods.Set(setReq)
err := ods.Set(context.TODO(), setReq)
assert.Nil(t, err)
connectionString := getConnectionString()
if getWalletLocation() != "" {
@ -580,10 +581,10 @@ func setNoTTLUpdatesExpiry(t *testing.T, ods *OracleDatabase) {
"ttlInSeconds": "1000",
},
}
err := ods.Set(setReq)
err := ods.Set(context.TODO(), setReq)
assert.Nil(t, err)
delete(setReq.Metadata, "ttlInSeconds")
err = ods.Set(setReq)
err = ods.Set(context.TODO(), setReq)
assert.Nil(t, err)
connectionString := getConnectionString()
if getWalletLocation() != "" {
@ -614,11 +615,11 @@ func expiredStateCannotBeRead(t *testing.T, ods *OracleDatabase) {
"ttlInSeconds": "1",
},
}
err := ods.Set(setReq)
err := ods.Set(context.TODO(), setReq)
assert.Nil(t, err)
time.Sleep(time.Second * time.Duration(2))
getResponse, err := ods.Get(&state.GetRequest{Key: key})
getResponse, err := ods.Get(context.TODO(), &state.GetRequest{Key: key})
assert.Equal(t, &state.GetResponse{}, getResponse, "Response must be empty")
assert.NoError(t, err, "Expired element must not be treated as error")
@ -639,7 +640,7 @@ func unexpiredStateCanBeRead(t *testing.T, ods *OracleDatabase) {
"ttlInSeconds": "10000",
},
}
err := ods.Set(setReq)
err := ods.Set(context.TODO(), setReq)
assert.Nil(t, err)
_, getValue := getItem(t, ods, key)
assert.Equal(t, value.Color, getValue.Color, "Response must be as set")
@ -653,7 +654,7 @@ func setItemWithNoKey(t *testing.T, ods *OracleDatabase) {
Key: "",
}
err := ods.Set(setReq)
err := ods.Set(context.TODO(), setReq)
assert.NotNil(t, err)
}
@ -702,7 +703,7 @@ func testSetItemWithInvalidTTL(t *testing.T, ods *OracleDatabase) {
"ttlInSeconds": "XX",
}),
}
err := ods.Set(setReq)
err := ods.Set(context.TODO(), setReq)
assert.NotNil(t, err, "Setting a value with a proper key and a incorrect TTL value should be produce an error")
}
@ -714,7 +715,7 @@ func testSetItemWithNegativeTTL(t *testing.T, ods *OracleDatabase) {
"ttlInSeconds": "-10",
}),
}
err := ods.Set(setReq)
err := ods.Set(context.TODO(), setReq)
assert.NotNil(t, err, "Setting a value with a proper key and a negative (other than -1) TTL value should be produce an error")
}
@ -731,7 +732,7 @@ func testBulkSetAndBulkDelete(t *testing.T, ods *OracleDatabase) {
},
}
err := ods.BulkSet(setReq)
err := ods.BulkSet(context.TODO(), setReq)
assert.Nil(t, err)
assert.True(t, storeItemExists(t, setReq[0].Key))
assert.True(t, storeItemExists(t, setReq[1].Key))
@ -745,7 +746,7 @@ func testBulkSetAndBulkDelete(t *testing.T, ods *OracleDatabase) {
},
}
err = ods.BulkDelete(deleteReq)
err = ods.BulkDelete(context.TODO(), deleteReq)
assert.Nil(t, err)
assert.False(t, storeItemExists(t, setReq[0].Key))
assert.False(t, storeItemExists(t, setReq[1].Key))
@ -812,7 +813,7 @@ func setItem(t *testing.T, ods *OracleDatabase, key string, value interface{}, e
Options: setOptions,
}
err := ods.Set(setReq)
err := ods.Set(context.TODO(), setReq)
assert.Nil(t, err)
itemExists := storeItemExists(t, key)
assert.True(t, itemExists, "Item should exist after set has been executed ")
@ -824,7 +825,7 @@ func getItem(t *testing.T, ods *OracleDatabase, key string) (*state.GetResponse,
Options: state.GetStateOption{},
}
response, getErr := ods.Get(getReq)
response, getErr := ods.Get(context.TODO(), getReq)
assert.Nil(t, getErr)
assert.NotNil(t, response)
outputObject := &fakeItem{}
@ -840,7 +841,7 @@ func deleteItem(t *testing.T, ods *OracleDatabase, key string, etag *string) {
Options: state.DeleteStateOption{},
}
deleteErr := ods.Delete(deleteReq)
deleteErr := ods.Delete(context.TODO(), deleteReq)
assert.Nil(t, deleteErr)
assert.False(t, storeItemExists(t, key), "item should no longer exist after delete has been performed")
}

View File

@ -15,6 +15,7 @@ limitations under the License.
package oracledatabase
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
@ -48,23 +49,23 @@ func (m *fakeDBaccess) Init(metadata state.Metadata) error {
return nil
}
func (m *fakeDBaccess) Set(req *state.SetRequest) error {
func (m *fakeDBaccess) Set(ctx context.Context, req *state.SetRequest) error {
m.setExecuted = true
return nil
}
func (m *fakeDBaccess) Get(req *state.GetRequest) (*state.GetResponse, error) {
func (m *fakeDBaccess) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) {
m.getExecuted = true
return nil, nil
}
func (m *fakeDBaccess) Delete(req *state.DeleteRequest) error {
func (m *fakeDBaccess) Delete(ctx context.Context, req *state.DeleteRequest) error {
return nil
}
func (m *fakeDBaccess) ExecuteMulti(sets []state.SetRequest, deletes []state.DeleteRequest) error {
func (m *fakeDBaccess) ExecuteMulti(ctx context.Context, sets []state.SetRequest, deletes []state.DeleteRequest) error {
return nil
}
@ -84,7 +85,7 @@ func TestMultiWithNoRequestsReturnsNil(t *testing.T) {
t.Parallel()
var operations []state.TransactionalStateOperation
ods := createOracleDatabase(t)
err := ods.Multi(&state.TransactionalStateRequest{
err := ods.Multi(context.TODO(), &state.TransactionalStateRequest{
Operations: operations,
})
assert.Nil(t, err)
@ -100,7 +101,7 @@ func TestInvalidMultiAction(t *testing.T) {
})
ods := createOracleDatabase(t)
err := ods.Multi(&state.TransactionalStateRequest{
err := ods.Multi(context.TODO(), &state.TransactionalStateRequest{
Operations: operations,
})
assert.NotNil(t, err)
@ -116,7 +117,7 @@ func TestValidSetRequest(t *testing.T) {
})
ods := createOracleDatabase(t)
err := ods.Multi(&state.TransactionalStateRequest{
err := ods.Multi(context.TODO(), &state.TransactionalStateRequest{
Operations: operations,
})
assert.Nil(t, err)
@ -132,7 +133,7 @@ func TestInvalidMultiSetRequest(t *testing.T) {
})
ods := createOracleDatabase(t)
err := ods.Multi(&state.TransactionalStateRequest{
err := ods.Multi(context.TODO(), &state.TransactionalStateRequest{
Operations: operations,
})
assert.NotNil(t, err)
@ -148,7 +149,7 @@ func TestValidMultiDeleteRequest(t *testing.T) {
})
ods := createOracleDatabase(t)
err := ods.Multi(&state.TransactionalStateRequest{
err := ods.Multi(context.TODO(), &state.TransactionalStateRequest{
Operations: operations,
})
assert.Nil(t, err)
@ -164,7 +165,7 @@ func TestInvalidMultiDeleteRequest(t *testing.T) {
})
ods := createOracleDatabase(t)
err := ods.Multi(&state.TransactionalStateRequest{
err := ods.Multi(context.TODO(), &state.TransactionalStateRequest{
Operations: operations,
})
assert.NotNil(t, err)

View File

@ -14,6 +14,7 @@ limitations under the License.
package oracledatabase
import (
"context"
"database/sql"
"encoding/base64"
"encoding/json"
@ -96,8 +97,8 @@ func (o *oracleDatabaseAccess) Init(metadata state.Metadata) error {
}
// Set makes an insert or update to the database.
func (o *oracleDatabaseAccess) Set(req *state.SetRequest) error {
return state.SetWithOptions(o.setValue, req)
func (o *oracleDatabaseAccess) Set(ctx context.Context, req *state.SetRequest) error {
return state.SetWithOptions(ctx, o.setValue, req)
}
func parseTTL(requestMetadata map[string]string) (*int, error) {
@ -115,7 +116,7 @@ func parseTTL(requestMetadata map[string]string) (*int, error) {
}
// setValue is an internal implementation of set to enable passing the logic to state.SetWithRetries as a func.
func (o *oracleDatabaseAccess) setValue(req *state.SetRequest) error {
func (o *oracleDatabaseAccess) setValue(ctx context.Context, req *state.SetRequest) error {
o.logger.Debug("Setting state value in OracleDatabase")
err := state.CheckRequestOptions(req.Options)
if err != nil {
@ -182,14 +183,14 @@ func (o *oracleDatabaseAccess) setValue(req *state.SetRequest) error {
WHEN MATCHED THEN UPDATE SET value = new_state_to_store.value, binary_yn = new_state_to_store.binary_yn, update_time = systimestamp, etag = new_state_to_store.etag, t.expiration_time = case when new_state_to_store.ttl_in_seconds >0 then systimestamp + numtodsinterval(new_state_to_store.ttl_in_seconds, 'SECOND') end
WHEN NOT MATCHED THEN INSERT (t.key, t.value, t.binary_yn, t.etag, t.expiration_time) values (new_state_to_store.key, new_state_to_store.value, new_state_to_store.binary_yn, new_state_to_store.etag, case when new_state_to_store.ttl_in_seconds >0 then systimestamp + numtodsinterval(new_state_to_store.ttl_in_seconds, 'SECOND') end ) `,
tableName)
result, err = tx.Exec(mergeStatement, req.Key, value, binaryYN, etag, ttlSeconds)
result, err = tx.ExecContext(ctx, mergeStatement, req.Key, value, binaryYN, etag, ttlSeconds)
} else {
// when first write policy is indicated, an existing record has to be updated - one that has the etag provided.
updateStatement := fmt.Sprintf(
`UPDATE %s SET value = :value, binary_yn = :binary_yn, etag = :new_etag
WHERE key = :key AND etag = :etag`,
tableName)
result, err = tx.Exec(updateStatement, value, binaryYN, etag, req.Key, *req.ETag)
result, err = tx.ExecContext(ctx, updateStatement, value, binaryYN, etag, req.Key, *req.ETag)
}
if err != nil {
if req.ETag != nil && *req.ETag != "" {
@ -214,7 +215,7 @@ func (o *oracleDatabaseAccess) setValue(req *state.SetRequest) error {
}
// Get returns data from the database. If data does not exist for the key an empty state.GetResponse will be returned.
func (o *oracleDatabaseAccess) Get(req *state.GetRequest) (*state.GetResponse, error) {
func (o *oracleDatabaseAccess) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) {
o.logger.Debug("Getting state value from OracleDatabase")
if req.Key == "" {
return nil, fmt.Errorf("missing key in get operation")
@ -222,7 +223,7 @@ func (o *oracleDatabaseAccess) Get(req *state.GetRequest) (*state.GetResponse, e
var value string
var binaryYN string
var etag string
err := o.db.QueryRow(fmt.Sprintf("SELECT value, binary_yn, etag FROM %s WHERE key = :key and (expiration_time is null or expiration_time > systimestamp)", tableName), req.Key).Scan(&value, &binaryYN, &etag)
err := o.db.QueryRowContext(ctx, fmt.Sprintf("SELECT value, binary_yn, etag FROM %s WHERE key = :key and (expiration_time is null or expiration_time > systimestamp)", tableName), req.Key).Scan(&value, &binaryYN, &etag)
if err != nil {
// If no rows exist, return an empty response, otherwise return the error.
if err == sql.ErrNoRows {
@ -253,12 +254,12 @@ func (o *oracleDatabaseAccess) Get(req *state.GetRequest) (*state.GetResponse, e
}
// Delete removes an item from the state store.
func (o *oracleDatabaseAccess) Delete(req *state.DeleteRequest) error {
return state.DeleteWithOptions(o.deleteValue, req)
func (o *oracleDatabaseAccess) Delete(ctx context.Context, req *state.DeleteRequest) error {
return state.DeleteWithOptions(ctx, o.deleteValue, req)
}
// deleteValue is an internal implementation of delete to enable passing the logic to state.DeleteWithRetries as a func.
func (o *oracleDatabaseAccess) deleteValue(req *state.DeleteRequest) error {
func (o *oracleDatabaseAccess) deleteValue(ctx context.Context, req *state.DeleteRequest) error {
o.logger.Debug("Deleting state value from OracleDatabase")
if req.Key == "" {
return fmt.Errorf("missing key in delete operation")
@ -280,9 +281,9 @@ func (o *oracleDatabaseAccess) deleteValue(req *state.DeleteRequest) error {
}
// QUESTION: only check for etag if FirstWrite specified - or always when etag is supplied??
if req.Options.Concurrency != state.FirstWrite {
result, err = tx.Exec("DELETE FROM state WHERE key = :key", req.Key)
result, err = tx.ExecContext(ctx, "DELETE FROM state WHERE key = :key", req.Key)
} else {
result, err = tx.Exec("DELETE FROM state WHERE key = :key and etag = :etag", req.Key, *req.ETag)
result, err = tx.ExecContext(ctx, "DELETE FROM state WHERE key = :key and etag = :etag", req.Key, *req.ETag)
}
if err != nil {
if o.tx == nil { // not joining a preexisting transaction.
@ -303,7 +304,7 @@ func (o *oracleDatabaseAccess) deleteValue(req *state.DeleteRequest) error {
return nil
}
func (o *oracleDatabaseAccess) ExecuteMulti(sets []state.SetRequest, deletes []state.DeleteRequest) error {
func (o *oracleDatabaseAccess) ExecuteMulti(ctx context.Context, sets []state.SetRequest, deletes []state.DeleteRequest) error {
o.logger.Debug("Executing multiple OracleDatabase operations, within a single transaction")
tx, err := o.db.Begin()
if err != nil {
@ -313,7 +314,7 @@ func (o *oracleDatabaseAccess) ExecuteMulti(sets []state.SetRequest, deletes []s
if len(deletes) > 0 {
for _, d := range deletes {
da := d // Fix for gosec G601: Implicit memory aliasing in for looo.
err = o.Delete(&da)
err = o.Delete(ctx, &da)
if err != nil {
tx.Rollback()
return err
@ -323,7 +324,7 @@ func (o *oracleDatabaseAccess) ExecuteMulti(sets []state.SetRequest, deletes []s
if len(sets) > 0 {
for _, s := range sets {
sa := s // Fix for gosec G601: Implicit memory aliasing in for looo.
err = o.Set(&sa)
err = o.Set(ctx, &sa)
if err != nil {
tx.Rollback()
return err

View File

@ -14,18 +14,20 @@ limitations under the License.
package postgresql
import (
"context"
"github.com/dapr/components-contrib/state"
)
// dbAccess is a private interface which enables unit testing of PostgreSQL.
type dbAccess interface {
Init(metadata state.Metadata) error
Set(req *state.SetRequest) error
BulkSet(req []state.SetRequest) error
Get(req *state.GetRequest) (*state.GetResponse, error)
Delete(req *state.DeleteRequest) error
BulkDelete(req []state.DeleteRequest) error
ExecuteMulti(req *state.TransactionalStateRequest) error
Query(req *state.QueryRequest) (*state.QueryResponse, error)
Set(ctx context.Context, req *state.SetRequest) error
BulkSet(ctx context.Context, req []state.SetRequest) error
Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error)
Delete(ctx context.Context, req *state.DeleteRequest) error
BulkDelete(ctx context.Context, req []state.DeleteRequest) error
ExecuteMulti(ctx context.Context, req *state.TransactionalStateRequest) error
Query(ctx context.Context, req *state.QueryRequest) (*state.QueryResponse, error)
Close() error // io.Closer
}

View File

@ -14,6 +14,7 @@ limitations under the License.
package postgresql
import (
"context"
"database/sql"
"encoding/base64"
"encoding/json"
@ -98,12 +99,12 @@ func (p *postgresDBAccess) Init(metadata state.Metadata) error {
}
// Set makes an insert or update to the database.
func (p *postgresDBAccess) Set(req *state.SetRequest) error {
return state.SetWithOptions(p.setValue, req)
func (p *postgresDBAccess) Set(ctx context.Context, req *state.SetRequest) error {
return state.SetWithOptions(ctx, p.setValue, req)
}
// setValue is an internal implementation of set to enable passing the logic to state.SetWithRetries as a func.
func (p *postgresDBAccess) setValue(req *state.SetRequest) error {
func (p *postgresDBAccess) setValue(ctx context.Context, req *state.SetRequest) error {
p.logger.Debug("Setting state value in PostgreSQL")
err := state.CheckRequestOptions(req.Options)
@ -134,7 +135,7 @@ func (p *postgresDBAccess) setValue(req *state.SetRequest) error {
// Sprintf is required for table name because sql.DB does not substitute parameters for table names.
// Other parameters use sql.DB parameter substitution.
if req.ETag == nil {
result, err = p.db.Exec(fmt.Sprintf(
result, err = p.db.ExecContext(ctx, fmt.Sprintf(
`INSERT INTO %s (key, value, isbinary) VALUES ($1, $2, $3)
ON CONFLICT (key) DO UPDATE SET value = $2, isbinary = $3, updatedate = NOW();`,
tableName), req.Key, value, isBinary)
@ -148,7 +149,7 @@ func (p *postgresDBAccess) setValue(req *state.SetRequest) error {
etag := uint32(etag64)
// When an etag is provided do an update - no insert
result, err = p.db.Exec(fmt.Sprintf(
result, err = p.db.ExecContext(ctx, fmt.Sprintf(
`UPDATE %s SET value = $1, isbinary = $2, updatedate = NOW()
WHERE key = $3 AND xmin = $4;`,
tableName), value, isBinary, req.Key, etag)
@ -174,7 +175,7 @@ func (p *postgresDBAccess) setValue(req *state.SetRequest) error {
return nil
}
func (p *postgresDBAccess) BulkSet(req []state.SetRequest) error {
func (p *postgresDBAccess) BulkSet(ctx context.Context, req []state.SetRequest) error {
p.logger.Debug("Executing BulkSet request")
tx, err := p.db.Begin()
if err != nil {
@ -184,7 +185,7 @@ func (p *postgresDBAccess) BulkSet(req []state.SetRequest) error {
if len(req) > 0 {
for _, s := range req {
sa := s // Fix for gosec G601: Implicit memory aliasing in for loop.
err = p.Set(&sa)
err = p.Set(ctx, &sa)
if err != nil {
tx.Rollback()
@ -199,7 +200,7 @@ func (p *postgresDBAccess) BulkSet(req []state.SetRequest) error {
}
// Get returns data from the database. If data does not exist for the key an empty state.GetResponse will be returned.
func (p *postgresDBAccess) Get(req *state.GetRequest) (*state.GetResponse, error) {
func (p *postgresDBAccess) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) {
p.logger.Debug("Getting state value from PostgreSQL")
if req.Key == "" {
return nil, fmt.Errorf("missing key in get operation")
@ -208,7 +209,7 @@ func (p *postgresDBAccess) Get(req *state.GetRequest) (*state.GetResponse, error
var value string
var isBinary bool
var etag int
err := p.db.QueryRow(fmt.Sprintf("SELECT value, isbinary, xmin as etag FROM %s WHERE key = $1", tableName), req.Key).Scan(&value, &isBinary, &etag)
err := p.db.QueryRowContext(ctx, fmt.Sprintf("SELECT value, isbinary, xmin as etag FROM %s WHERE key = $1", tableName), req.Key).Scan(&value, &isBinary, &etag)
if err != nil {
// If no rows exist, return an empty response, otherwise return the error.
if err == sql.ErrNoRows {
@ -245,12 +246,12 @@ func (p *postgresDBAccess) Get(req *state.GetRequest) (*state.GetResponse, error
}
// Delete removes an item from the state store.
func (p *postgresDBAccess) Delete(req *state.DeleteRequest) error {
return state.DeleteWithOptions(p.deleteValue, req)
func (p *postgresDBAccess) Delete(ctx context.Context, req *state.DeleteRequest) error {
return state.DeleteWithOptions(ctx, p.deleteValue, req)
}
// deleteValue is an internal implementation of delete to enable passing the logic to state.DeleteWithRetries as a func.
func (p *postgresDBAccess) deleteValue(req *state.DeleteRequest) error {
func (p *postgresDBAccess) deleteValue(ctx context.Context, req *state.DeleteRequest) error {
p.logger.Debug("Deleting state value from PostgreSQL")
if req.Key == "" {
return fmt.Errorf("missing key in delete operation")
@ -260,7 +261,7 @@ func (p *postgresDBAccess) deleteValue(req *state.DeleteRequest) error {
var err error
if req.ETag == nil {
result, err = p.db.Exec("DELETE FROM state WHERE key = $1", req.Key)
result, err = p.db.ExecContext(ctx, "DELETE FROM state WHERE key = $1", req.Key)
} else {
// Convert req.ETag to uint32 for postgres XID compatibility
var etag64 uint64
@ -270,7 +271,7 @@ func (p *postgresDBAccess) deleteValue(req *state.DeleteRequest) error {
}
etag := uint32(etag64)
result, err = p.db.Exec("DELETE FROM state WHERE key = $1 and xmin = $2", req.Key, etag)
result, err = p.db.ExecContext(ctx, "DELETE FROM state WHERE key = $1 and xmin = $2", req.Key, etag)
}
if err != nil {
@ -289,7 +290,7 @@ func (p *postgresDBAccess) deleteValue(req *state.DeleteRequest) error {
return nil
}
func (p *postgresDBAccess) BulkDelete(req []state.DeleteRequest) error {
func (p *postgresDBAccess) BulkDelete(ctx context.Context, req []state.DeleteRequest) error {
p.logger.Debug("Executing BulkDelete request")
tx, err := p.db.Begin()
if err != nil {
@ -299,7 +300,7 @@ func (p *postgresDBAccess) BulkDelete(req []state.DeleteRequest) error {
if len(req) > 0 {
for _, d := range req {
da := d // Fix for gosec G601: Implicit memory aliasing in for loop.
err = p.Delete(&da)
err = p.Delete(ctx, &da)
if err != nil {
tx.Rollback()
@ -313,7 +314,7 @@ func (p *postgresDBAccess) BulkDelete(req []state.DeleteRequest) error {
return err
}
func (p *postgresDBAccess) ExecuteMulti(request *state.TransactionalStateRequest) error {
func (p *postgresDBAccess) ExecuteMulti(ctx context.Context, request *state.TransactionalStateRequest) error {
p.logger.Debug("Executing PostgreSQL transaction")
tx, err := p.db.Begin()
@ -332,7 +333,7 @@ func (p *postgresDBAccess) ExecuteMulti(request *state.TransactionalStateRequest
return err
}
err = p.Set(&setReq)
err = p.Set(ctx, &setReq)
if err != nil {
tx.Rollback()
return err
@ -347,7 +348,7 @@ func (p *postgresDBAccess) ExecuteMulti(request *state.TransactionalStateRequest
return err
}
err = p.Delete(&delReq)
err = p.Delete(ctx, &delReq)
if err != nil {
tx.Rollback()
return err
@ -365,7 +366,7 @@ func (p *postgresDBAccess) ExecuteMulti(request *state.TransactionalStateRequest
}
// Query executes a query against store.
func (p *postgresDBAccess) Query(req *state.QueryRequest) (*state.QueryResponse, error) {
func (p *postgresDBAccess) Query(ctx context.Context, req *state.QueryRequest) (*state.QueryResponse, error) {
p.logger.Debug("Getting query value from PostgreSQL")
q := &Query{
query: "",
@ -375,7 +376,7 @@ func (p *postgresDBAccess) Query(req *state.QueryRequest) (*state.QueryResponse,
if err := qbuilder.BuildQuery(&req.Query); err != nil {
return &state.QueryResponse{}, err
}
data, token, err := q.execute(p.logger, p.db)
data, token, err := q.execute(ctx, p.logger, p.db)
if err != nil {
return &state.QueryResponse{}, err
}

View File

@ -15,6 +15,7 @@ limitations under the License.
package postgresql
import (
"context"
"database/sql"
"testing"
@ -110,7 +111,7 @@ func TestMultiWithNoRequests(t *testing.T) {
var operations []state.TransactionalStateOperation
// Act
err := m.pgDba.ExecuteMulti(&state.TransactionalStateRequest{
err := m.pgDba.ExecuteMulti(context.TODO(), &state.TransactionalStateRequest{
Operations: operations,
})
@ -134,7 +135,7 @@ func TestInvalidMultiInvalidAction(t *testing.T) {
})
// Act
err := m.pgDba.ExecuteMulti(&state.TransactionalStateRequest{
err := m.pgDba.ExecuteMulti(context.TODO(), &state.TransactionalStateRequest{
Operations: operations,
})
@ -159,7 +160,7 @@ func TestValidSetRequest(t *testing.T) {
})
// Act
err := m.pgDba.ExecuteMulti(&state.TransactionalStateRequest{
err := m.pgDba.ExecuteMulti(context.TODO(), &state.TransactionalStateRequest{
Operations: operations,
})
@ -183,7 +184,7 @@ func TestInvalidMultiSetRequest(t *testing.T) {
})
// Act
err := m.pgDba.ExecuteMulti(&state.TransactionalStateRequest{
err := m.pgDba.ExecuteMulti(context.TODO(), &state.TransactionalStateRequest{
Operations: operations,
})
@ -207,7 +208,7 @@ func TestInvalidMultiSetRequestNoKey(t *testing.T) {
})
// Act
err := m.pgDba.ExecuteMulti(&state.TransactionalStateRequest{
err := m.pgDba.ExecuteMulti(context.TODO(), &state.TransactionalStateRequest{
Operations: operations,
})
@ -232,7 +233,7 @@ func TestValidMultiDeleteRequest(t *testing.T) {
})
// Act
err := m.pgDba.ExecuteMulti(&state.TransactionalStateRequest{
err := m.pgDba.ExecuteMulti(context.TODO(), &state.TransactionalStateRequest{
Operations: operations,
})
@ -256,7 +257,7 @@ func TestInvalidMultiDeleteRequest(t *testing.T) {
})
// Act
err := m.pgDba.ExecuteMulti(&state.TransactionalStateRequest{
err := m.pgDba.ExecuteMulti(context.TODO(), &state.TransactionalStateRequest{
Operations: operations,
})
@ -280,7 +281,7 @@ func TestInvalidMultiDeleteRequestNoKey(t *testing.T) {
})
// Act
err := m.pgDba.ExecuteMulti(&state.TransactionalStateRequest{
err := m.pgDba.ExecuteMulti(context.TODO(), &state.TransactionalStateRequest{
Operations: operations,
})
@ -312,7 +313,7 @@ func TestMultiOperationOrder(t *testing.T) {
)
// Act
err := m.pgDba.ExecuteMulti(&state.TransactionalStateRequest{
err := m.pgDba.ExecuteMulti(context.TODO(), &state.TransactionalStateRequest{
Operations: operations,
})
@ -335,7 +336,7 @@ func TestInvalidBulkSetNoKey(t *testing.T) {
})
// Act
err := m.pgDba.BulkSet(sets)
err := m.pgDba.BulkSet(context.TODO(), sets)
// Assert
assert.NotNil(t, err)
@ -357,7 +358,7 @@ func TestInvalidBulkSetEmptyValue(t *testing.T) {
})
// Act
err := m.pgDba.BulkSet(sets)
err := m.pgDba.BulkSet(context.TODO(), sets)
// Assert
assert.NotNil(t, err)
@ -380,7 +381,7 @@ func TestValidBulkSet(t *testing.T) {
})
// Act
err := m.pgDba.BulkSet(sets)
err := m.pgDba.BulkSet(context.TODO(), sets)
// Assert
assert.Nil(t, err)
@ -401,7 +402,7 @@ func TestInvalidBulkDeleteNoKey(t *testing.T) {
})
// Act
err := m.pgDba.BulkDelete(deletes)
err := m.pgDba.BulkDelete(context.TODO(), deletes)
// Assert
assert.NotNil(t, err)
@ -423,7 +424,7 @@ func TestValidBulkDelete(t *testing.T) {
})
// Act
err := m.pgDba.BulkDelete(deletes)
err := m.pgDba.BulkDelete(context.TODO(), deletes)
// Assert
assert.Nil(t, err)

View File

@ -14,6 +14,8 @@ limitations under the License.
package postgresql
import (
"context"
"github.com/dapr/components-contrib/state"
"github.com/dapr/kit/logger"
)
@ -53,44 +55,44 @@ func (p *PostgreSQL) Features() []state.Feature {
}
// Delete removes an entity from the store.
func (p *PostgreSQL) Delete(req *state.DeleteRequest) error {
return p.dbaccess.Delete(req)
func (p *PostgreSQL) Delete(ctx context.Context, req *state.DeleteRequest) error {
return p.dbaccess.Delete(ctx, req)
}
// BulkDelete removes multiple entries from the store.
func (p *PostgreSQL) BulkDelete(req []state.DeleteRequest) error {
return p.dbaccess.BulkDelete(req)
func (p *PostgreSQL) BulkDelete(ctx context.Context, req []state.DeleteRequest) error {
return p.dbaccess.BulkDelete(ctx, req)
}
// Get returns an entity from store.
func (p *PostgreSQL) Get(req *state.GetRequest) (*state.GetResponse, error) {
return p.dbaccess.Get(req)
func (p *PostgreSQL) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) {
return p.dbaccess.Get(ctx, req)
}
// BulkGet performs a bulks get operations.
func (p *PostgreSQL) BulkGet(req []state.GetRequest) (bool, []state.BulkGetResponse, error) {
func (p *PostgreSQL) BulkGet(ctx context.Context, req []state.GetRequest) (bool, []state.BulkGetResponse, error) {
// TODO: replace with ExecuteMulti for performance
return false, nil, nil
}
// Set adds/updates an entity on store.
func (p *PostgreSQL) Set(req *state.SetRequest) error {
return p.dbaccess.Set(req)
func (p *PostgreSQL) Set(ctx context.Context, req *state.SetRequest) error {
return p.dbaccess.Set(ctx, req)
}
// BulkSet adds/updates multiple entities on store.
func (p *PostgreSQL) BulkSet(req []state.SetRequest) error {
return p.dbaccess.BulkSet(req)
func (p *PostgreSQL) BulkSet(ctx context.Context, req []state.SetRequest) error {
return p.dbaccess.BulkSet(ctx, req)
}
// Multi handles multiple transactions. Implements TransactionalStore.
func (p *PostgreSQL) Multi(request *state.TransactionalStateRequest) error {
return p.dbaccess.ExecuteMulti(request)
func (p *PostgreSQL) Multi(ctx context.Context, request *state.TransactionalStateRequest) error {
return p.dbaccess.ExecuteMulti(ctx, request)
}
// Query executes a query against store.
func (p *PostgreSQL) Query(req *state.QueryRequest) (*state.QueryResponse, error) {
return p.dbaccess.Query(req)
func (p *PostgreSQL) Query(ctx context.Context, req *state.QueryRequest) (*state.QueryResponse, error) {
return p.dbaccess.Query(ctx, req)
}
// Close implements io.Closer.

View File

@ -15,6 +15,7 @@ limitations under the License.
package postgresql
import (
"context"
"database/sql"
"encoding/json"
"fmt"
@ -192,7 +193,7 @@ func deleteItemThatDoesNotExist(t *testing.T, pgs *PostgreSQL) {
deleteReq := &state.DeleteRequest{
Key: randomKey(),
}
err := pgs.Delete(deleteReq)
err := pgs.Delete(context.TODO(), deleteReq)
assert.Nil(t, err)
}
@ -211,7 +212,7 @@ func multiWithSetOnly(t *testing.T, pgs *PostgreSQL) {
})
}
err := pgs.Multi(&state.TransactionalStateRequest{
err := pgs.Multi(context.TODO(), &state.TransactionalStateRequest{
Operations: operations,
})
assert.Nil(t, err)
@ -241,7 +242,7 @@ func multiWithDeleteOnly(t *testing.T, pgs *PostgreSQL) {
})
}
err := pgs.Multi(&state.TransactionalStateRequest{
err := pgs.Multi(context.TODO(), &state.TransactionalStateRequest{
Operations: operations,
})
assert.Nil(t, err)
@ -284,7 +285,7 @@ func multiWithDeleteAndSet(t *testing.T, pgs *PostgreSQL) {
})
}
err := pgs.Multi(&state.TransactionalStateRequest{
err := pgs.Multi(context.TODO(), &state.TransactionalStateRequest{
Operations: operations,
})
assert.Nil(t, err)
@ -311,7 +312,7 @@ func deleteWithInvalidEtagFails(t *testing.T, pgs *PostgreSQL) {
Key: key,
ETag: &etag,
}
err := pgs.Delete(deleteReq)
err := pgs.Delete(context.TODO(), deleteReq)
assert.NotNil(t, err)
}
@ -319,7 +320,7 @@ func deleteWithNoKeyFails(t *testing.T, pgs *PostgreSQL) {
deleteReq := &state.DeleteRequest{
Key: "",
}
err := pgs.Delete(deleteReq)
err := pgs.Delete(context.TODO(), deleteReq)
assert.NotNil(t, err)
}
@ -334,7 +335,7 @@ func newItemWithEtagFails(t *testing.T, pgs *PostgreSQL) {
Value: value,
}
err := pgs.Set(setReq)
err := pgs.Set(context.TODO(), setReq)
assert.NotNil(t, err)
}
@ -360,7 +361,7 @@ func updateWithOldEtagFails(t *testing.T, pgs *PostgreSQL) {
ETag: originalEtag,
Value: newValue,
}
err := pgs.Set(setReq)
err := pgs.Set(context.TODO(), setReq)
assert.NotNil(t, err)
}
@ -402,7 +403,7 @@ func getItemWithNoKey(t *testing.T, pgs *PostgreSQL) {
Key: "",
}
response, getErr := pgs.Get(getReq)
response, getErr := pgs.Get(context.TODO(), getReq)
assert.NotNil(t, getErr)
assert.Nil(t, response)
}
@ -433,7 +434,7 @@ func setItemWithNoKey(t *testing.T, pgs *PostgreSQL) {
Key: "",
}
err := pgs.Set(setReq)
err := pgs.Set(context.TODO(), setReq)
assert.NotNil(t, err)
}
@ -450,7 +451,7 @@ func testBulkSetAndBulkDelete(t *testing.T, pgs *PostgreSQL) {
},
}
err := pgs.BulkSet(setReq)
err := pgs.BulkSet(context.TODO(), setReq)
assert.Nil(t, err)
assert.True(t, storeItemExists(t, setReq[0].Key))
assert.True(t, storeItemExists(t, setReq[1].Key))
@ -464,7 +465,7 @@ func testBulkSetAndBulkDelete(t *testing.T, pgs *PostgreSQL) {
},
}
err = pgs.BulkDelete(deleteReq)
err = pgs.BulkDelete(context.TODO(), deleteReq)
assert.Nil(t, err)
assert.False(t, storeItemExists(t, setReq[0].Key))
assert.False(t, storeItemExists(t, setReq[1].Key))
@ -521,7 +522,7 @@ func setItem(t *testing.T, pgs *PostgreSQL, key string, value interface{}, etag
Value: value,
}
err := pgs.Set(setReq)
err := pgs.Set(context.TODO(), setReq)
assert.Nil(t, err)
itemExists := storeItemExists(t, key)
assert.True(t, itemExists)
@ -533,7 +534,7 @@ func getItem(t *testing.T, pgs *PostgreSQL, key string) (*state.GetResponse, *fa
Options: state.GetStateOption{},
}
response, getErr := pgs.Get(getReq)
response, getErr := pgs.Get(context.TODO(), getReq)
assert.Nil(t, getErr)
assert.NotNil(t, response)
outputObject := &fakeItem{}
@ -549,7 +550,7 @@ func deleteItem(t *testing.T, pgs *PostgreSQL, key string, etag *string) {
Options: state.DeleteStateOption{},
}
deleteErr := pgs.Delete(deleteReq)
deleteErr := pgs.Delete(context.TODO(), deleteReq)
assert.Nil(t, deleteErr)
assert.False(t, storeItemExists(t, key))
}

View File

@ -15,6 +15,7 @@ limitations under the License.
package postgresql
import (
"context"
"database/sql"
"fmt"
"strconv"
@ -139,8 +140,8 @@ func (q *Query) Finalize(filters string, qq *query.Query) error {
return nil
}
func (q *Query) execute(logger logger.Logger, db *sql.DB) ([]state.QueryItem, string, error) {
rows, err := db.Query(q.query, q.params...)
func (q *Query) execute(ctx context.Context, logger logger.Logger, db *sql.DB) ([]state.QueryItem, string, error) {
rows, err := db.QueryContext(ctx, q.query, q.params...)
if err != nil {
return nil, "", err
}

View File

@ -15,6 +15,7 @@ limitations under the License.
package postgresql
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
@ -43,37 +44,37 @@ func (m *fakeDBaccess) Init(metadata state.Metadata) error {
return nil
}
func (m *fakeDBaccess) Set(req *state.SetRequest) error {
func (m *fakeDBaccess) Set(ctx context.Context, req *state.SetRequest) error {
m.setExecuted = true
return nil
}
func (m *fakeDBaccess) Get(req *state.GetRequest) (*state.GetResponse, error) {
func (m *fakeDBaccess) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) {
m.getExecuted = true
return nil, nil
}
func (m *fakeDBaccess) Delete(req *state.DeleteRequest) error {
func (m *fakeDBaccess) Delete(ctx context.Context, req *state.DeleteRequest) error {
m.deleteExecuted = true
return nil
}
func (m *fakeDBaccess) BulkSet(req []state.SetRequest) error {
func (m *fakeDBaccess) BulkSet(ctx context.Context, req []state.SetRequest) error {
return nil
}
func (m *fakeDBaccess) BulkDelete(req []state.DeleteRequest) error {
func (m *fakeDBaccess) BulkDelete(ctx context.Context, req []state.DeleteRequest) error {
return nil
}
func (m *fakeDBaccess) ExecuteMulti(req *state.TransactionalStateRequest) error {
func (m *fakeDBaccess) ExecuteMulti(ctx context.Context, req *state.TransactionalStateRequest) error {
return nil
}
func (m *fakeDBaccess) Query(req *state.QueryRequest) (*state.QueryResponse, error) {
func (m *fakeDBaccess) Query(ctx context.Context, req *state.QueryRequest) (*state.QueryResponse, error) {
return nil, nil
}

View File

@ -195,7 +195,7 @@ func (r *StateStore) parseConnectedSlaves(res string) int {
return 0
}
func (r *StateStore) deleteValue(req *state.DeleteRequest) error {
func (r *StateStore) deleteValue(ctx context.Context, req *state.DeleteRequest) error {
if req.ETag == nil {
etag := "0"
req.ETag = &etag
@ -207,7 +207,7 @@ func (r *StateStore) deleteValue(req *state.DeleteRequest) error {
} else {
delQuery = delDefaultQuery
}
_, err := r.client.Do(r.ctx, "EVAL", delQuery, 1, req.Key, *req.ETag).Result()
_, err := r.client.Do(ctx, "EVAL", delQuery, 1, req.Key, *req.ETag).Result()
if err != nil {
return state.NewETagError(state.ETagMismatch, err)
}
@ -216,17 +216,17 @@ func (r *StateStore) deleteValue(req *state.DeleteRequest) error {
}
// Delete performs a delete operation.
func (r *StateStore) Delete(req *state.DeleteRequest) error {
func (r *StateStore) Delete(ctx context.Context, req *state.DeleteRequest) error {
err := state.CheckRequestOptions(req.Options)
if err != nil {
return err
}
return state.DeleteWithOptions(r.deleteValue, req)
return state.DeleteWithOptions(ctx, r.deleteValue, req)
}
func (r *StateStore) directGet(req *state.GetRequest) (*state.GetResponse, error) {
res, err := r.client.Do(r.ctx, "GET", req.Key).Result()
func (r *StateStore) directGet(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) {
res, err := r.client.Do(ctx, "GET", req.Key).Result()
if err != nil {
return nil, err
}
@ -242,10 +242,10 @@ func (r *StateStore) directGet(req *state.GetRequest) (*state.GetResponse, error
}, nil
}
func (r *StateStore) getDefault(req *state.GetRequest) (*state.GetResponse, error) {
res, err := r.client.Do(r.ctx, "HGETALL", req.Key).Result() // Prefer values with ETags
func (r *StateStore) getDefault(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) {
res, err := r.client.Do(ctx, "HGETALL", req.Key).Result() // Prefer values with ETags
if err != nil {
return r.directGet(req) // Falls back to original get for backward compats.
return r.directGet(ctx, req) // Falls back to original get for backward compats.
}
if res == nil {
return &state.GetResponse{}, nil
@ -304,12 +304,12 @@ func (r *StateStore) getJSON(req *state.GetRequest) (*state.GetResponse, error)
}
// Get retrieves state from redis with a key.
func (r *StateStore) Get(req *state.GetRequest) (*state.GetResponse, error) {
func (r *StateStore) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) {
if contentType, ok := req.Metadata[daprmetadata.ContentType]; ok && contentType == contenttype.JSONContentType {
return r.getJSON(req)
}
return r.getDefault(req)
return r.getDefault(ctx, req)
}
type jsonEntry struct {
@ -317,7 +317,7 @@ type jsonEntry struct {
Version *int `json:"version,omitempty"`
}
func (r *StateStore) setValue(req *state.SetRequest) error {
func (r *StateStore) setValue(ctx context.Context, req *state.SetRequest) error {
err := state.CheckRequestOptions(req.Options)
if err != nil {
return err
@ -350,7 +350,7 @@ func (r *StateStore) setValue(req *state.SetRequest) error {
bt, _ = utils.Marshal(req.Value, r.json.Marshal)
}
err = r.client.Do(r.ctx, "EVAL", setQuery, 1, req.Key, ver, bt, firstWrite).Err()
err = r.client.Do(ctx, "EVAL", setQuery, 1, req.Key, ver, bt, firstWrite).Err()
if err != nil {
if req.ETag != nil {
return state.NewETagError(state.ETagMismatch, err)
@ -360,21 +360,21 @@ func (r *StateStore) setValue(req *state.SetRequest) error {
}
if ttl != nil && *ttl > 0 {
_, err = r.client.Do(r.ctx, "EXPIRE", req.Key, *ttl).Result()
_, err = r.client.Do(ctx, "EXPIRE", req.Key, *ttl).Result()
if err != nil {
return fmt.Errorf("failed to set key %s ttl: %s", req.Key, err)
}
}
if ttl != nil && *ttl <= 0 {
_, err = r.client.Do(r.ctx, "PERSIST", req.Key).Result()
_, err = r.client.Do(ctx, "PERSIST", req.Key).Result()
if err != nil {
return fmt.Errorf("failed to persist key %s: %s", req.Key, err)
}
}
if req.Options.Consistency == state.Strong && r.replicas > 0 {
_, err = r.client.Do(r.ctx, "WAIT", r.replicas, 1000).Result()
_, err = r.client.Do(ctx, "WAIT", r.replicas, 1000).Result()
if err != nil {
return fmt.Errorf("redis waiting for %v replicas to acknowledge write, err: %s", r.replicas, err.Error())
}
@ -384,12 +384,12 @@ func (r *StateStore) setValue(req *state.SetRequest) error {
}
// Set saves state into redis.
func (r *StateStore) Set(req *state.SetRequest) error {
return state.SetWithOptions(r.setValue, req)
func (r *StateStore) Set(ctx context.Context, req *state.SetRequest) error {
return state.SetWithOptions(ctx, r.setValue, req)
}
// Multi performs a transactional operation. succeeds only if all operations succeed, and fails if one or more operations fail.
func (r *StateStore) Multi(request *state.TransactionalStateRequest) error {
func (r *StateStore) Multi(ctx context.Context, request *state.TransactionalStateRequest) error {
var setQuery, delQuery string
var isJSON bool
if contentType, ok := request.Metadata[daprmetadata.ContentType]; ok && contentType == contenttype.JSONContentType {
@ -423,12 +423,12 @@ func (r *StateStore) Multi(request *state.TransactionalStateRequest) error {
} else {
bt, _ = utils.Marshal(req.Value, r.json.Marshal)
}
pipe.Do(r.ctx, "EVAL", setQuery, 1, req.Key, ver, bt)
pipe.Do(ctx, "EVAL", setQuery, 1, req.Key, ver, bt)
if ttl != nil && *ttl > 0 {
pipe.Do(r.ctx, "EXPIRE", req.Key, *ttl)
pipe.Do(ctx, "EXPIRE", req.Key, *ttl)
}
if ttl != nil && *ttl <= 0 {
pipe.Do(r.ctx, "PERSIST", req.Key)
pipe.Do(ctx, "PERSIST", req.Key)
}
} else if o.Operation == state.Delete {
req := o.Request.(state.DeleteRequest)
@ -436,11 +436,11 @@ func (r *StateStore) Multi(request *state.TransactionalStateRequest) error {
etag := "0"
req.ETag = &etag
}
pipe.Do(r.ctx, "EVAL", delQuery, 1, req.Key, *req.ETag)
pipe.Do(ctx, "EVAL", delQuery, 1, req.Key, *req.ETag)
}
}
_, err := pipe.Exec(r.ctx)
_, err := pipe.Exec(ctx)
return err
}
@ -514,7 +514,7 @@ func (r *StateStore) parseTTL(req *state.SetRequest) (*int, error) {
}
// Query executes a query against store.
func (r *StateStore) Query(req *state.QueryRequest) (*state.QueryResponse, error) {
func (r *StateStore) Query(ctx context.Context, req *state.QueryRequest) (*state.QueryResponse, error) {
indexName, ok := daprmetadata.TryGetQueryIndexName(req.Metadata)
if !ok {
return nil, fmt.Errorf("query index not found")
@ -529,7 +529,7 @@ func (r *StateStore) Query(req *state.QueryRequest) (*state.QueryResponse, error
if err := qbuilder.BuildQuery(&req.Query); err != nil {
return &state.QueryResponse{}, err
}
data, token, err := q.execute(r.ctx, r.client)
data, token, err := q.execute(ctx, r.client)
if err != nil {
return &state.QueryResponse{}, err
}

View File

@ -206,7 +206,7 @@ func TestTransactionalUpsert(t *testing.T) {
}
ss.ctx, ss.cancel = context.WithCancel(context.Background())
err := ss.Multi(&state.TransactionalStateRequest{
err := ss.Multi(context.TODO(), &state.TransactionalStateRequest{
Operations: []state.TransactionalStateOperation{
{
Operation: state.Upsert,
@ -273,13 +273,13 @@ func TestTransactionalDelete(t *testing.T) {
ss.ctx, ss.cancel = context.WithCancel(context.Background())
// Insert a record first.
ss.Set(&state.SetRequest{
ss.Set(context.TODO(), &state.SetRequest{
Key: "weapon",
Value: "deathstar",
})
etag := "1"
err := ss.Multi(&state.TransactionalStateRequest{
err := ss.Multi(context.TODO(), &state.TransactionalStateRequest{
Operations: []state.TransactionalStateOperation{{
Operation: state.Delete,
Request: state.DeleteRequest{
@ -331,7 +331,7 @@ func TestRequestsWithGlobalTTL(t *testing.T) {
ss.ctx, ss.cancel = context.WithCancel(context.Background())
t.Run("TTL: Only global specified", func(t *testing.T) {
ss.Set(&state.SetRequest{
ss.Set(context.TODO(), &state.SetRequest{
Key: "weapon100",
Value: "deathstar100",
})
@ -342,7 +342,7 @@ func TestRequestsWithGlobalTTL(t *testing.T) {
t.Run("TTL: Global and Request specified", func(t *testing.T) {
requestTTL := 200
ss.Set(&state.SetRequest{
ss.Set(context.TODO(), &state.SetRequest{
Key: "weapon100",
Value: "deathstar100",
Metadata: map[string]string{
@ -355,7 +355,7 @@ func TestRequestsWithGlobalTTL(t *testing.T) {
})
t.Run("TTL: Global and Request specified", func(t *testing.T) {
err := ss.Multi(&state.TransactionalStateRequest{
err := ss.Multi(context.TODO(), &state.TransactionalStateRequest{
Operations: []state.TransactionalStateOperation{
{
Operation: state.Upsert,
@ -424,7 +424,7 @@ func TestSetRequestWithTTL(t *testing.T) {
t.Run("TTL specified", func(t *testing.T) {
ttlInSeconds := 100
ss.Set(&state.SetRequest{
ss.Set(context.TODO(), &state.SetRequest{
Key: "weapon100",
Value: "deathstar100",
Metadata: map[string]string{
@ -438,7 +438,7 @@ func TestSetRequestWithTTL(t *testing.T) {
})
t.Run("TTL not specified", func(t *testing.T) {
ss.Set(&state.SetRequest{
ss.Set(context.TODO(), &state.SetRequest{
Key: "weapon200",
Value: "deathstar200",
})
@ -449,7 +449,7 @@ func TestSetRequestWithTTL(t *testing.T) {
})
t.Run("TTL Changed for Existing Key", func(t *testing.T) {
ss.Set(&state.SetRequest{
ss.Set(context.TODO(), &state.SetRequest{
Key: "weapon300",
Value: "deathstar300",
})
@ -458,7 +458,7 @@ func TestSetRequestWithTTL(t *testing.T) {
// make the key no longer persistent
ttlInSeconds := 123
ss.Set(&state.SetRequest{
ss.Set(context.TODO(), &state.SetRequest{
Key: "weapon300",
Value: "deathstar300",
Metadata: map[string]string{
@ -469,7 +469,7 @@ func TestSetRequestWithTTL(t *testing.T) {
assert.Equal(t, time.Duration(ttlInSeconds)*time.Second, ttl)
// make the key persistent again
ss.Set(&state.SetRequest{
ss.Set(context.TODO(), &state.SetRequest{
Key: "weapon300",
Value: "deathstar301",
Metadata: map[string]string{
@ -493,12 +493,12 @@ func TestTransactionalDeleteNoEtag(t *testing.T) {
ss.ctx, ss.cancel = context.WithCancel(context.Background())
// Insert a record first.
ss.Set(&state.SetRequest{
ss.Set(context.TODO(), &state.SetRequest{
Key: "weapon100",
Value: "deathstar100",
})
err := ss.Multi(&state.TransactionalStateRequest{
err := ss.Multi(context.TODO(), &state.TransactionalStateRequest{
Operations: []state.TransactionalStateOperation{{
Operation: state.Delete,
Request: state.DeleteRequest{

View File

@ -14,6 +14,7 @@ limitations under the License.
package state
import (
"context"
"fmt"
)
@ -68,11 +69,11 @@ func validateConsistencyOption(c string) error {
}
// SetWithOptions handles SetRequest with request options.
func SetWithOptions(method func(req *SetRequest) error, req *SetRequest) error {
return method(req)
func SetWithOptions(ctx context.Context, method func(ctx context.Context, req *SetRequest) error, req *SetRequest) error {
return method(ctx, req)
}
// DeleteWithOptions handles DeleteRequest with options.
func DeleteWithOptions(method func(req *DeleteRequest) error, req *DeleteRequest) error {
return method(req)
func DeleteWithOptions(ctx context.Context, method func(ctx context.Context, req *DeleteRequest) error, req *DeleteRequest) error {
return method(ctx, req)
}

View File

@ -14,6 +14,7 @@ limitations under the License.
package state
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
@ -23,7 +24,7 @@ import (
func TestSetRequestWithOptions(t *testing.T) {
t.Run("set with default options", func(t *testing.T) {
counter := 0
SetWithOptions(func(req *SetRequest) error {
SetWithOptions(context.TODO(), func(ctx context.Context, req *SetRequest) error {
counter++
return nil
@ -33,7 +34,7 @@ func TestSetRequestWithOptions(t *testing.T) {
t.Run("set with no explicit options", func(t *testing.T) {
counter := 0
SetWithOptions(func(req *SetRequest) error {
SetWithOptions(context.TODO(), func(ctx context.Context, req *SetRequest) error {
counter++
return nil

View File

@ -14,6 +14,7 @@ limitations under the License.
package rethinkdb
import (
"context"
"encoding/json"
"io"
"strconv"
@ -147,7 +148,7 @@ func tableExists(arr []string, table string) bool {
}
// Get retrieves a RethinkDB KV item.
func (s *RethinkDB) Get(req *state.GetRequest) (*state.GetResponse, error) {
func (s *RethinkDB) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) {
if req == nil || req.Key == "" {
return nil, errors.New("invalid state request, missing key")
}
@ -187,22 +188,22 @@ func (s *RethinkDB) Get(req *state.GetRequest) (*state.GetResponse, error) {
}
// BulkGet performs a bulks get operations.
func (s *RethinkDB) BulkGet(req []state.GetRequest) (bool, []state.BulkGetResponse, error) {
func (s *RethinkDB) BulkGet(ctx context.Context, req []state.GetRequest) (bool, []state.BulkGetResponse, error) {
// TODO: replace with bulk get for performance
return false, nil, nil
}
// Set saves a state KV item.
func (s *RethinkDB) Set(req *state.SetRequest) error {
func (s *RethinkDB) Set(ctx context.Context, req *state.SetRequest) error {
if req == nil || req.Key == "" || req.Value == nil {
return errors.New("invalid state request, key and value required")
}
return s.BulkSet([]state.SetRequest{*req})
return s.BulkSet(ctx, []state.SetRequest{*req})
}
// BulkSet performs a bulk save operation.
func (s *RethinkDB) BulkSet(req []state.SetRequest) error {
func (s *RethinkDB) BulkSet(ctx context.Context, req []state.SetRequest) error {
docs := make([]*stateRecord, len(req))
for i, v := range req {
var etag string
@ -257,16 +258,16 @@ func (s *RethinkDB) archive(changes []r.ChangeResponse) error {
}
// Delete performes a RethinkDB KV delete operation.
func (s *RethinkDB) Delete(req *state.DeleteRequest) error {
func (s *RethinkDB) Delete(ctx context.Context, req *state.DeleteRequest) error {
if req == nil || req.Key == "" {
return errors.New("invalid request, missing key")
}
return s.BulkDelete([]state.DeleteRequest{*req})
return s.BulkDelete(ctx, []state.DeleteRequest{*req})
}
// BulkDelete performs a bulk delete operation.
func (s *RethinkDB) BulkDelete(req []state.DeleteRequest) error {
func (s *RethinkDB) BulkDelete(ctx context.Context, req []state.DeleteRequest) error {
list := make([]string, 0)
for _, d := range req {
list = append(list, d.Key)
@ -282,7 +283,7 @@ func (s *RethinkDB) BulkDelete(req []state.DeleteRequest) error {
}
// Multi performs multiple operations.
func (s *RethinkDB) Multi(req *state.TransactionalStateRequest) error {
func (s *RethinkDB) Multi(ctx context.Context, req *state.TransactionalStateRequest) error {
upserts := make([]state.SetRequest, 0)
deletes := make([]state.DeleteRequest, 0)
@ -306,11 +307,11 @@ func (s *RethinkDB) Multi(req *state.TransactionalStateRequest) error {
}
// best effort, no transacts supported
if err := s.BulkSet(upserts); err != nil {
if err := s.BulkSet(ctx, upserts); err != nil {
return errors.Wrap(err, "error saving records to the database")
}
if err := s.BulkDelete(deletes); err != nil {
if err := s.BulkDelete(ctx, deletes); err != nil {
return errors.Wrap(err, "error deleting records to the database")
}

View File

@ -14,6 +14,7 @@ limitations under the License.
package rethinkdb
import (
"context"
"encoding/json"
"fmt"
"os"
@ -86,12 +87,12 @@ func TestRethinkDBStateStore(t *testing.T) {
d := &testObj{F1: "test", F2: 1, F3: time.Now().UTC()}
k := fmt.Sprintf("ids-%d", time.Now().UnixNano())
if err := db.Set(&state.SetRequest{Key: k, Value: d}); err != nil {
if err := db.Set(context.TODO(), &state.SetRequest{Key: k, Value: d}); err != nil {
t.Fatalf("error setting data to db: %v", err)
}
// get set data and compare
resp, err := db.Get(&state.GetRequest{Key: k})
resp, err := db.Get(context.TODO(), &state.GetRequest{Key: k})
assert.Nil(t, err)
d2 := testGetTestObj(t, resp)
assert.NotNil(t, d2)
@ -103,12 +104,12 @@ func TestRethinkDBStateStore(t *testing.T) {
d2.F2 = 2
d2.F3 = time.Now().UTC()
tag := fmt.Sprintf("hash-%d", time.Now().UnixNano())
if err = db.Set(&state.SetRequest{Key: k, Value: d2, ETag: &tag}); err != nil {
if err = db.Set(context.TODO(), &state.SetRequest{Key: k, Value: d2, ETag: &tag}); err != nil {
t.Fatalf("error setting data to db: %v", err)
}
// get updated data and compare
resp2, err := db.Get(&state.GetRequest{Key: k})
resp2, err := db.Get(context.TODO(), &state.GetRequest{Key: k})
assert.Nil(t, err)
d3 := testGetTestObj(t, resp2)
assert.NotNil(t, d3)
@ -117,7 +118,7 @@ func TestRethinkDBStateStore(t *testing.T) {
assert.Equal(t, d2.F3.Format(time.RFC3339), d3.F3.Format(time.RFC3339))
// delete data
if err := db.Delete(&state.DeleteRequest{Key: k}); err != nil {
if err := db.Delete(context.TODO(), &state.DeleteRequest{Key: k}); err != nil {
t.Fatalf("error on data deletion: %v", err)
}
})
@ -127,19 +128,19 @@ func TestRethinkDBStateStore(t *testing.T) {
d := []byte("test")
k := fmt.Sprintf("idb-%d", time.Now().UnixNano())
if err := db.Set(&state.SetRequest{Key: k, Value: d}); err != nil {
if err := db.Set(context.TODO(), &state.SetRequest{Key: k, Value: d}); err != nil {
t.Fatalf("error setting data to db: %v", err)
}
// get set data and compare
resp, err := db.Get(&state.GetRequest{Key: k})
resp, err := db.Get(context.TODO(), &state.GetRequest{Key: k})
assert.Nil(t, err)
assert.NotNil(t, resp)
assert.NotNil(t, resp.Data)
assert.Equal(t, string(d), string(resp.Data))
// delete data
if err := db.Delete(&state.DeleteRequest{Key: k}); err != nil {
if err := db.Delete(context.TODO(), &state.DeleteRequest{Key: k}); err != nil {
t.Fatalf("error on data deletion: %v", err)
}
})
@ -177,26 +178,26 @@ func testBulk(t *testing.T, db *RethinkDB, i int) {
}
// bulk set it
if err := db.BulkSet(setList); err != nil {
if err := db.BulkSet(context.TODO(), setList); err != nil {
t.Fatalf("error setting data to db: %v -- run %d", err, i)
}
// check for the data
for _, v := range deleteList {
resp, err := db.Get(&state.GetRequest{Key: v.Key})
resp, err := db.Get(context.TODO(), &state.GetRequest{Key: v.Key})
assert.Nilf(t, err, " -- run %d", i)
assert.NotNil(t, resp)
assert.NotNil(t, resp.Data)
}
// delete data
if err := db.BulkDelete(deleteList); err != nil {
if err := db.BulkDelete(context.TODO(), deleteList); err != nil {
t.Fatalf("error on data deletion: %v -- run %d", err, i)
}
// check for the data NOT being there
for _, v := range deleteList {
resp, err := db.Get(&state.GetRequest{Key: v.Key})
resp, err := db.Get(context.TODO(), &state.GetRequest{Key: v.Key})
assert.Nilf(t, err, " -- run %d", i)
assert.NotNil(t, resp)
assert.Nil(t, resp.Data)
@ -224,7 +225,7 @@ func TestRethinkDBStateStoreMulti(t *testing.T) {
for i := 0; i < numOfRecords; i++ {
list[i] = state.SetRequest{Key: fmt.Sprintf(recordIDFormat, i), Value: d}
}
if err := db.BulkSet(list); err != nil {
if err := db.BulkSet(context.TODO(), list); err != nil {
t.Fatalf("error setting multi to db: %v", err)
}
@ -258,19 +259,19 @@ func TestRethinkDBStateStoreMulti(t *testing.T) {
}
// execute multi
if err := db.Multi(req); err != nil {
if err := db.Multi(context.TODO(), req); err != nil {
t.Fatalf("error setting multi to db: %v", err)
}
// the one not deleted should be still there
m1, err := db.Get(&state.GetRequest{Key: fmt.Sprintf(recordIDFormat, 1)})
m1, err := db.Get(context.TODO(), &state.GetRequest{Key: fmt.Sprintf(recordIDFormat, 1)})
assert.Nil(t, err)
assert.NotNil(t, m1)
assert.NotNil(t, m1.Data)
assert.Equal(t, string(d2), string(m1.Data))
// the one deleted should not
m2, err := db.Get(&state.GetRequest{Key: fmt.Sprintf(recordIDFormat, 3)})
m2, err := db.Get(context.TODO(), &state.GetRequest{Key: fmt.Sprintf(recordIDFormat, 3)})
assert.Nil(t, err)
assert.NotNil(t, m2)
assert.Nil(t, m2.Data)

View File

@ -14,6 +14,7 @@ limitations under the License.
package sqlserver
import (
"context"
"database/sql"
"encoding/hex"
"encoding/json"
@ -338,7 +339,7 @@ func (s *SQLServer) Features() []state.Feature {
}
// Multi performs multiple updates on a Sql server store.
func (s *SQLServer) Multi(request *state.TransactionalStateRequest) error {
func (s *SQLServer) Multi(ctx context.Context, request *state.TransactionalStateRequest) error {
tx, err := s.db.Begin()
if err != nil {
return err
@ -353,7 +354,7 @@ func (s *SQLServer) Multi(request *state.TransactionalStateRequest) error {
return err
}
err = s.executeSet(tx, &setReq)
err = s.executeSet(ctx, tx, &setReq)
if err != nil {
tx.Rollback()
return err
@ -366,7 +367,7 @@ func (s *SQLServer) Multi(request *state.TransactionalStateRequest) error {
return err
}
err = s.executeDelete(tx, &delReq)
err = s.executeDelete(ctx, tx, &delReq)
if err != nil {
tx.Rollback()
return err
@ -410,11 +411,11 @@ func (s *SQLServer) getDeletes(req state.TransactionalStateOperation) (state.Del
}
// Delete removes an entity from the store.
func (s *SQLServer) Delete(req *state.DeleteRequest) error {
return s.executeDelete(s.db, req)
func (s *SQLServer) Delete(ctx context.Context, req *state.DeleteRequest) error {
return s.executeDelete(ctx, s.db, req)
}
func (s *SQLServer) executeDelete(db dbExecutor, req *state.DeleteRequest) error {
func (s *SQLServer) executeDelete(ctx context.Context, db dbExecutor, req *state.DeleteRequest) error {
var err error
var res sql.Result
if req.ETag != nil {
@ -424,9 +425,9 @@ func (s *SQLServer) executeDelete(db dbExecutor, req *state.DeleteRequest) error
return state.NewETagError(state.ETagInvalid, err)
}
res, err = db.Exec(s.deleteWithETagCommand, sql.Named(keyColumnName, req.Key), sql.Named(rowVersionColumnName, b))
res, err = db.ExecContext(ctx, s.deleteWithETagCommand, sql.Named(keyColumnName, req.Key), sql.Named(rowVersionColumnName, b))
} else {
res, err = db.Exec(s.deleteWithoutETagCommand, sql.Named(keyColumnName, req.Key))
res, err = db.ExecContext(ctx, s.deleteWithoutETagCommand, sql.Named(keyColumnName, req.Key))
}
// err represents errors thrown by the stored procedure or the database itself
@ -456,13 +457,13 @@ type TvpDeleteTableStringKey struct {
}
// BulkDelete removes multiple entries from the store.
func (s *SQLServer) BulkDelete(req []state.DeleteRequest) error {
func (s *SQLServer) BulkDelete(ctx context.Context, req []state.DeleteRequest) error {
tx, err := s.db.Begin()
if err != nil {
return err
}
err = s.executeBulkDelete(tx, req)
err = s.executeBulkDelete(ctx, tx, req)
if err != nil {
tx.Rollback()
@ -474,7 +475,7 @@ func (s *SQLServer) BulkDelete(req []state.DeleteRequest) error {
return nil
}
func (s *SQLServer) executeBulkDelete(db dbExecutor, req []state.DeleteRequest) error {
func (s *SQLServer) executeBulkDelete(ctx context.Context, db dbExecutor, req []state.DeleteRequest) error {
values := make([]TvpDeleteTableStringKey, len(req))
for i, d := range req {
var etag []byte
@ -493,7 +494,7 @@ func (s *SQLServer) executeBulkDelete(db dbExecutor, req []state.DeleteRequest)
Value: values,
}
res, err := db.Exec(s.bulkDeleteCommand, sql.Named("itemsToDelete", itemsToDelete))
res, err := db.ExecContext(ctx, s.bulkDeleteCommand, sql.Named("itemsToDelete", itemsToDelete))
if err != nil {
return err
}
@ -513,7 +514,7 @@ func (s *SQLServer) executeBulkDelete(db dbExecutor, req []state.DeleteRequest)
}
// Get returns an entity from store.
func (s *SQLServer) Get(req *state.GetRequest) (*state.GetResponse, error) {
func (s *SQLServer) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) {
rows, err := s.db.Query(s.getCommand, sql.Named(keyColumnName, req.Key))
if err != nil {
return nil, err
@ -545,21 +546,21 @@ func (s *SQLServer) Get(req *state.GetRequest) (*state.GetResponse, error) {
}
// BulkGet performs a bulks get operations.
func (s *SQLServer) BulkGet(req []state.GetRequest) (bool, []state.BulkGetResponse, error) {
func (s *SQLServer) BulkGet(ctx context.Context, req []state.GetRequest) (bool, []state.BulkGetResponse, error) {
return false, nil, nil
}
// Set adds/updates an entity on store.
func (s *SQLServer) Set(req *state.SetRequest) error {
return s.executeSet(s.db, req)
func (s *SQLServer) Set(ctx context.Context, req *state.SetRequest) error {
return s.executeSet(ctx, s.db, req)
}
// dbExecutor implements a common functionality implemented by db or tx.
type dbExecutor interface {
Exec(query string, args ...interface{}) (sql.Result, error)
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
}
func (s *SQLServer) executeSet(db dbExecutor, req *state.SetRequest) error {
func (s *SQLServer) executeSet(ctx context.Context, db dbExecutor, req *state.SetRequest) error {
var err error
var bytes []byte
bytes, err = utils.Marshal(req.Value, json.Marshal)
@ -578,9 +579,9 @@ func (s *SQLServer) executeSet(db dbExecutor, req *state.SetRequest) error {
var res sql.Result
if req.Options.Concurrency == state.FirstWrite {
res, err = db.Exec(s.upsertCommand, sql.Named(keyColumnName, req.Key), sql.Named("Data", string(bytes)), etag, sql.Named("FirstWrite", 1))
res, err = db.ExecContext(ctx, s.upsertCommand, sql.Named(keyColumnName, req.Key), sql.Named("Data", string(bytes)), etag, sql.Named("FirstWrite", 1))
} else {
res, err = db.Exec(s.upsertCommand, sql.Named(keyColumnName, req.Key), sql.Named("Data", string(bytes)), etag, sql.Named("FirstWrite", 0))
res, err = db.ExecContext(ctx, s.upsertCommand, sql.Named(keyColumnName, req.Key), sql.Named("Data", string(bytes)), etag, sql.Named("FirstWrite", 0))
}
if err != nil {
@ -604,14 +605,14 @@ func (s *SQLServer) executeSet(db dbExecutor, req *state.SetRequest) error {
}
// BulkSet adds/updates multiple entities on store.
func (s *SQLServer) BulkSet(req []state.SetRequest) error {
func (s *SQLServer) BulkSet(ctx context.Context, req []state.SetRequest) error {
tx, err := s.db.Begin()
if err != nil {
return err
}
for i := range req {
err = s.executeSet(tx, &req[i])
err = s.executeSet(ctx, tx, &req[i])
if err != nil {
tx.Rollback()

View File

@ -15,6 +15,7 @@ limitations under the License.
package sqlserver
import (
"context"
"database/sql"
"encoding/json"
"fmt"
@ -130,7 +131,7 @@ func getTestStoreWithKeyType(t *testing.T, kt KeyType, indexedProperties string)
}
func assertUserExists(t *testing.T, store *SQLServer, key string) (user, string) {
getRes, err := store.Get(&state.GetRequest{Key: key})
getRes, err := store.Get(context.TODO(), &state.GetRequest{Key: key})
assert.Nil(t, err)
assert.NotNil(t, getRes)
assert.NotNil(t, getRes.Data, "No data was returned")
@ -153,7 +154,7 @@ func assertLoadedUserIsEqual(t *testing.T, store *SQLServer, key string, expecte
}
func assertUserDoesNotExist(t *testing.T, store *SQLServer, key string) {
_, err := store.Get(&state.GetRequest{Key: key})
_, err := store.Get(context.TODO(), &state.GetRequest{Key: key})
assert.Nil(t, err)
}
@ -224,14 +225,14 @@ func testSingleOperations(t *testing.T) {
assertUserDoesNotExist(t, store, john.ID)
// Save and read
err := store.Set(&state.SetRequest{Key: john.ID, Value: john})
err := store.Set(context.TODO(), &state.SetRequest{Key: john.ID, Value: john})
assert.Nil(t, err)
johnV1, etagFromInsert := assertLoadedUserIsEqual(t, store, john.ID, john)
// Update with ETAG
waterJohn := johnV1
waterJohn.FavoriteBeverage = "Water"
err = store.Set(&state.SetRequest{Key: waterJohn.ID, Value: waterJohn, ETag: &etagFromInsert})
err = store.Set(context.TODO(), &state.SetRequest{Key: waterJohn.ID, Value: waterJohn, ETag: &etagFromInsert})
assert.Nil(t, err)
// Get updated
@ -240,7 +241,7 @@ func testSingleOperations(t *testing.T) {
// Update without ETAG
noEtagJohn := johnV2
noEtagJohn.FavoriteBeverage = "No Etag John"
err = store.Set(&state.SetRequest{Key: noEtagJohn.ID, Value: noEtagJohn})
err = store.Set(context.TODO(), &state.SetRequest{Key: noEtagJohn.ID, Value: noEtagJohn})
assert.Nil(t, err)
// 7. Get updated
@ -249,17 +250,17 @@ func testSingleOperations(t *testing.T) {
// 8. Update with invalid ETAG should fail
failedJohn := johnV3
failedJohn.FavoriteBeverage = "Will not work"
err = store.Set(&state.SetRequest{Key: failedJohn.ID, Value: failedJohn, ETag: &etagFromInsert})
err = store.Set(context.TODO(), &state.SetRequest{Key: failedJohn.ID, Value: failedJohn, ETag: &etagFromInsert})
assert.NotNil(t, err)
_, etag := assertLoadedUserIsEqual(t, store, johnV3.ID, johnV3)
// 9. Delete with invalid ETAG should fail
err = store.Delete(&state.DeleteRequest{Key: johnV3.ID, ETag: &invEtag})
err = store.Delete(context.TODO(), &state.DeleteRequest{Key: johnV3.ID, ETag: &invEtag})
assert.NotNil(t, err)
assertLoadedUserIsEqual(t, store, johnV3.ID, johnV3)
// 10. Delete with valid ETAG
err = store.Delete(&state.DeleteRequest{Key: johnV2.ID, ETag: &etag})
err = store.Delete(context.TODO(), &state.DeleteRequest{Key: johnV2.ID, ETag: &etag})
assert.Nil(t, err)
assertUserDoesNotExist(t, store, johnV2.ID)
@ -273,7 +274,7 @@ func testSetNewRecordWithInvalidEtagShouldFail(t *testing.T) {
u := user{uuid.New().String(), "John", "Coffee"}
invEtag := invalidEtag
err := store.Set(&state.SetRequest{Key: u.ID, Value: u, ETag: &invEtag})
err := store.Set(context.TODO(), &state.SetRequest{Key: u.ID, Value: u, ETag: &invEtag})
assert.NotNil(t, err)
}
@ -281,7 +282,7 @@ func testSetNewRecordWithInvalidEtagShouldFail(t *testing.T) {
func testIndexedProperties(t *testing.T) {
store := getTestStore(t, `[{ "column":"FavoriteBeverage", "property":"FavoriteBeverage", "type":"nvarchar(100)"}, { "column":"PetsCount", "property":"PetsCount", "type": "INTEGER"}]`)
err := store.BulkSet([]state.SetRequest{
err := store.BulkSet(context.TODO(), []state.SetRequest{
{Key: "1", Value: userWithPets{user{"1", "John", "Coffee"}, 3}},
{Key: "2", Value: userWithPets{user{"2", "Laura", "Water"}, 1}},
{Key: "3", Value: userWithPets{user{"3", "Carl", "Beer"}, 0}},
@ -343,7 +344,7 @@ func testMultiOperations(t *testing.T) {
bulkSet[i] = state.SetRequest{Key: u.ID, Value: u}
}
err := store.BulkSet(bulkSet)
err := store.BulkSet(context.TODO(), bulkSet)
assert.Nil(t, err)
assertUserCountIsEqualTo(t, store, len(initialUsers))
@ -363,7 +364,7 @@ func testMultiOperations(t *testing.T) {
modified := original.user
modified.FavoriteBeverage = beverageTea
localErr := store.Multi(&state.TransactionalStateRequest{
localErr := store.Multi(context.TODO(), &state.TransactionalStateRequest{
Operations: []state.TransactionalStateOperation{
{Operation: state.Delete, Request: state.DeleteRequest{Key: toDelete.ID}},
{Operation: state.Upsert, Request: state.SetRequest{Key: modified.ID, Value: modified}},
@ -386,7 +387,7 @@ func testMultiOperations(t *testing.T) {
modified := toModify.user
modified.FavoriteBeverage = beverageTea
err = store.Multi(&state.TransactionalStateRequest{
err = store.Multi(context.TODO(), &state.TransactionalStateRequest{
Operations: []state.TransactionalStateOperation{
{Operation: state.Delete, Request: state.DeleteRequest{Key: toDelete.ID, ETag: &toDelete.etag}},
{Operation: state.Upsert, Request: state.SetRequest{Key: modified.ID, Value: modified, ETag: &toModify.etag}},
@ -410,7 +411,7 @@ func testMultiOperations(t *testing.T) {
modified := toModify.user
modified.FavoriteBeverage = beverageTea
err = store.Multi(&state.TransactionalStateRequest{
err = store.Multi(context.TODO(), &state.TransactionalStateRequest{
Operations: []state.TransactionalStateOperation{
{Operation: state.Delete, Request: state.DeleteRequest{Key: toDelete.ID, ETag: &toDelete.etag}},
{Operation: state.Upsert, Request: state.SetRequest{Key: modified.ID, Value: modified, ETag: &toModify.etag}},
@ -431,7 +432,7 @@ func testMultiOperations(t *testing.T) {
toInsert := user{keyGen.NextKey(), "Wont-be-inserted", "Beer"}
invEtag := invalidEtag
err = store.Multi(&state.TransactionalStateRequest{
err = store.Multi(context.TODO(), &state.TransactionalStateRequest{
Operations: []state.TransactionalStateOperation{
{Operation: state.Delete, Request: state.DeleteRequest{Key: toDelete.ID, ETag: &invEtag}},
{Operation: state.Upsert, Request: state.SetRequest{Key: toInsert.ID, Value: toInsert}},
@ -452,7 +453,7 @@ func testMultiOperations(t *testing.T) {
modified.FavoriteBeverage = beverageTea
invEtag := invalidEtag
err = store.Multi(&state.TransactionalStateRequest{
err = store.Multi(context.TODO(), &state.TransactionalStateRequest{
Operations: []state.TransactionalStateOperation{
{Operation: state.Delete, Request: state.DeleteRequest{Key: toDelete.ID, ETag: &invEtag}},
{Operation: state.Upsert, Request: state.SetRequest{Key: modified.ID, Value: modified}},
@ -472,7 +473,7 @@ func testMultiOperations(t *testing.T) {
modified.FavoriteBeverage = beverageTea
invEtag := invalidEtag
err = store.Multi(&state.TransactionalStateRequest{
err = store.Multi(context.TODO(), &state.TransactionalStateRequest{
Operations: []state.TransactionalStateOperation{
{Operation: state.Delete, Request: state.DeleteRequest{Key: toDelete.ID}},
{Operation: state.Upsert, Request: state.SetRequest{Key: modified.ID, Value: modified, ETag: &invEtag}},
@ -520,7 +521,7 @@ func testBulkSet(t *testing.T) {
sets[i] = state.SetRequest{Key: u.ID, Value: u}
}
err := store.BulkSet(sets)
err := store.BulkSet(context.TODO(), sets)
assert.Nil(t, err)
totalUsers = len(sets)
assertUserCountIsEqualTo(t, store, totalUsers)
@ -532,7 +533,7 @@ func testBulkSet(t *testing.T) {
modified.FavoriteBeverage = beverageTea
toInsert := user{keyGen.NextKey(), "Maria", "Wine"}
err := store.BulkSet([]state.SetRequest{
err := store.BulkSet(context.TODO(), []state.SetRequest{
{Key: modified.ID, Value: modified, ETag: &toModifyETag},
{Key: toInsert.ID, Value: toInsert},
})
@ -551,7 +552,7 @@ func testBulkSet(t *testing.T) {
modified.FavoriteBeverage = beverageTea
toInsert := user{keyGen.NextKey(), "Tony", "Milk"}
err := store.BulkSet([]state.SetRequest{
err := store.BulkSet(context.TODO(), []state.SetRequest{
{Key: modified.ID, Value: modified},
{Key: toInsert.ID, Value: toInsert},
})
@ -578,7 +579,7 @@ func testBulkSet(t *testing.T) {
{Key: modified.ID, Value: modified, ETag: &invEtag},
}
err := store.BulkSet(sets)
err := store.BulkSet(context.TODO(), sets)
assert.NotNil(t, err)
assertUserCountIsEqualTo(t, store, totalUsers)
assertUserDoesNotExist(t, store, toInsert1.ID)
@ -621,7 +622,7 @@ func testBulkDelete(t *testing.T) {
for i, u := range initialUsers {
sets[i] = state.SetRequest{Key: u.ID, Value: u}
}
err := store.BulkSet(sets)
err := store.BulkSet(context.TODO(), sets)
assert.Nil(t, err)
totalUsers := len(initialUsers)
assertUserCountIsEqualTo(t, store, totalUsers)
@ -631,7 +632,7 @@ func testBulkDelete(t *testing.T) {
t.Run("Delete 2 items without etag should work", func(t *testing.T) {
deleted1 := initialUsers[userIndex].ID
deleted2 := initialUsers[userIndex+1].ID
err := store.BulkDelete([]state.DeleteRequest{
err := store.BulkDelete(context.TODO(), []state.DeleteRequest{
{Key: deleted1},
{Key: deleted2},
})
@ -648,7 +649,7 @@ func testBulkDelete(t *testing.T) {
deleted1, deleted1Etag := assertUserExists(t, store, initialUsers[userIndex].ID)
deleted2, deleted2Etag := assertUserExists(t, store, initialUsers[userIndex+1].ID)
err := store.BulkDelete([]state.DeleteRequest{
err := store.BulkDelete(context.TODO(), []state.DeleteRequest{
{Key: deleted1.ID, ETag: &deleted1Etag},
{Key: deleted2.ID, ETag: &deleted2Etag},
})
@ -665,7 +666,7 @@ func testBulkDelete(t *testing.T) {
deleted1, deleted1Etag := assertUserExists(t, store, initialUsers[userIndex].ID)
deleted2 := initialUsers[userIndex+1]
err := store.BulkDelete([]state.DeleteRequest{
err := store.BulkDelete(context.TODO(), []state.DeleteRequest{
{Key: deleted1.ID, ETag: &deleted1Etag},
{Key: deleted2.ID},
})
@ -683,7 +684,7 @@ func testBulkDelete(t *testing.T) {
deleted2 := initialUsers[userIndex+1]
invEtag := invalidEtag
err := store.BulkDelete([]state.DeleteRequest{
err := store.BulkDelete(context.TODO(), []state.DeleteRequest{
{Key: deleted1.ID, ETag: &deleted1Etag},
{Key: deleted2.ID, ETag: &invEtag},
})
@ -703,7 +704,7 @@ func testInsertAndUpdateSetRecordDates(t *testing.T) {
store := getTestStore(t, "")
u := user{"1", "John", "Coffee"}
err := store.Set(&state.SetRequest{Key: u.ID, Value: u})
err := store.Set(context.TODO(), &state.SetRequest{Key: u.ID, Value: u})
assert.Nil(t, err)
var originalInsertTime time.Time
@ -725,7 +726,7 @@ func testInsertAndUpdateSetRecordDates(t *testing.T) {
modified := u
modified.FavoriteBeverage = beverageTea
err = store.Set(&state.SetRequest{Key: modified.ID, Value: modified})
err = store.Set(context.TODO(), &state.SetRequest{Key: modified.ID, Value: modified})
assert.Nil(t, err)
assertDBQuery(t, store, getUserTsql, func(t *testing.T, rows *sql.Rows) {
assert.True(t, rows.Next())
@ -749,7 +750,7 @@ func testConcurrentSets(t *testing.T) {
store := getTestStore(t, "")
u := user{"1", "John", "Coffee"}
err := store.Set(&state.SetRequest{Key: u.ID, Value: u})
err := store.Set(context.TODO(), &state.SetRequest{Key: u.ID, Value: u})
assert.Nil(t, err)
_, etag := assertLoadedUserIsEqual(t, store, u.ID, u)
@ -766,7 +767,7 @@ func testConcurrentSets(t *testing.T) {
defer wc.Done()
modified := user{"1", "John", beverageTea}
err := store.Set(&state.SetRequest{Key: id, Value: modified, ETag: &etag})
err := store.Set(context.TODO(), &state.SetRequest{Key: id, Value: modified, ETag: &etag})
if err != nil {
atomic.AddInt32(&totalErrors, 1)
} else {

View File

@ -14,6 +14,7 @@ limitations under the License.
package state
import (
"context"
"fmt"
"github.com/dapr/components-contrib/health"
@ -24,9 +25,9 @@ type Store interface {
BulkStore
Init(metadata Metadata) error
Features() []Feature
Delete(req *DeleteRequest) error
Get(req *GetRequest) (*GetResponse, error)
Set(req *SetRequest) error
Delete(ctx context.Context, req *DeleteRequest) error
Get(ctx context.Context, req *GetRequest) (*GetResponse, error)
Set(ctx context.Context, req *SetRequest) error
}
func Ping(store Store) error {
@ -40,9 +41,9 @@ func Ping(store Store) error {
// BulkStore is an interface to perform bulk operations on store.
type BulkStore interface {
BulkDelete(req []DeleteRequest) error
BulkGet(req []GetRequest) (bool, []BulkGetResponse, error)
BulkSet(req []SetRequest) error
BulkDelete(ctx context.Context, req []DeleteRequest) error
BulkGet(ctx context.Context, req []GetRequest) (bool, []BulkGetResponse, error)
BulkSet(ctx context.Context, req []SetRequest) error
}
// DefaultBulkStore is a default implementation of BulkStore.
@ -64,16 +65,16 @@ func (b *DefaultBulkStore) Features() []Feature {
}
// BulkGet performs a bulks get operations.
func (b *DefaultBulkStore) BulkGet(req []GetRequest) (bool, []BulkGetResponse, error) {
func (b *DefaultBulkStore) BulkGet(ctx context.Context, req []GetRequest) (bool, []BulkGetResponse, error) {
// by default, the store doesn't support bulk get
// return false so daprd will fallback to call get() method one by one
return false, nil, nil
}
// BulkSet performs a bulks save operation.
func (b *DefaultBulkStore) BulkSet(req []SetRequest) error {
func (b *DefaultBulkStore) BulkSet(ctx context.Context, req []SetRequest) error {
for i := range req {
err := b.s.Set(&req[i])
err := b.s.Set(ctx, &req[i])
if err != nil {
return err
}
@ -83,9 +84,9 @@ func (b *DefaultBulkStore) BulkSet(req []SetRequest) error {
}
// BulkDelete performs a bulk delete operation.
func (b *DefaultBulkStore) BulkDelete(req []DeleteRequest) error {
func (b *DefaultBulkStore) BulkDelete(ctx context.Context, req []DeleteRequest) error {
for i := range req {
err := b.s.Delete(&req[i])
err := b.s.Delete(ctx, &req[i])
if err != nil {
return err
}
@ -96,5 +97,5 @@ func (b *DefaultBulkStore) BulkDelete(req []DeleteRequest) error {
// Querier is an interface to execute queries.
type Querier interface {
Query(req *QueryRequest) (*QueryResponse, error)
Query(ctx context.Context, req *QueryRequest) (*QueryResponse, error)
}

View File

@ -14,6 +14,7 @@ limitations under the License.
package state
import (
"context"
"testing"
"github.com/stretchr/testify/require"
@ -26,22 +27,22 @@ func TestStore_withDefaultBulkImpl(t *testing.T) {
require.Equal(t, s.count, 0)
require.Equal(t, s.bulkCount, 0)
store.Get(&GetRequest{})
store.Set(&SetRequest{})
store.Delete(&DeleteRequest{})
store.Get(context.TODO(), &GetRequest{})
store.Set(context.TODO(), &SetRequest{})
store.Delete(context.TODO(), &DeleteRequest{})
require.Equal(t, 3, s.count)
require.Equal(t, 0, s.bulkCount)
bulkGet, responses, err := store.BulkGet([]GetRequest{{}, {}, {}})
bulkGet, responses, err := store.BulkGet(context.TODO(), []GetRequest{{}, {}, {}})
require.Equal(t, false, bulkGet)
require.Equal(t, 0, len(responses))
require.NoError(t, err)
require.Equal(t, 3, s.count)
require.Equal(t, 0, s.bulkCount)
store.BulkSet([]SetRequest{{}, {}, {}, {}})
store.BulkSet(context.TODO(), []SetRequest{{}, {}, {}, {}})
require.Equal(t, 3+4, s.count)
require.Equal(t, 0, s.bulkCount)
store.BulkDelete([]DeleteRequest{{}, {}, {}, {}, {}})
store.BulkDelete(context.TODO(), []DeleteRequest{{}, {}, {}, {}, {}})
require.Equal(t, 3+4+5, s.count)
require.Equal(t, 0, s.bulkCount)
}
@ -52,20 +53,20 @@ func TestStore_withCustomisedBulkImpl_notSupportBulkGet(t *testing.T) {
require.Equal(t, s.count, 0)
require.Equal(t, s.bulkCount, 0)
store.Get(&GetRequest{})
store.Set(&SetRequest{})
store.Delete(&DeleteRequest{})
store.Get(context.TODO(), &GetRequest{})
store.Set(context.TODO(), &SetRequest{})
store.Delete(context.TODO(), &DeleteRequest{})
require.Equal(t, 3, s.count)
require.Equal(t, 0, s.bulkCount)
bulkGet, _, _ := store.BulkGet([]GetRequest{{}, {}, {}})
bulkGet, _, _ := store.BulkGet(context.TODO(), []GetRequest{{}, {}, {}})
require.Equal(t, false, bulkGet)
require.Equal(t, 6, s.count)
require.Equal(t, 0, s.bulkCount)
store.BulkSet([]SetRequest{{}, {}, {}, {}})
store.BulkSet(context.TODO(), []SetRequest{{}, {}, {}, {}})
require.Equal(t, 6, s.count)
require.Equal(t, 1, s.bulkCount)
store.BulkDelete([]DeleteRequest{{}, {}, {}, {}, {}})
store.BulkDelete(context.TODO(), []DeleteRequest{{}, {}, {}, {}, {}})
require.Equal(t, 6, s.count)
require.Equal(t, 2, s.bulkCount)
}
@ -76,20 +77,20 @@ func TestStore_withCustomisedBulkImpl_supportBulkGet(t *testing.T) {
require.Equal(t, s.count, 0)
require.Equal(t, s.bulkCount, 0)
store.Get(&GetRequest{})
store.Set(&SetRequest{})
store.Delete(&DeleteRequest{})
store.Get(context.TODO(), &GetRequest{})
store.Set(context.TODO(), &SetRequest{})
store.Delete(context.TODO(), &DeleteRequest{})
require.Equal(t, 3, s.count)
require.Equal(t, 0, s.bulkCount)
bulkGet, _, _ := store.BulkGet([]GetRequest{{}, {}, {}})
bulkGet, _, _ := store.BulkGet(context.TODO(), []GetRequest{{}, {}, {}})
require.Equal(t, true, bulkGet)
require.Equal(t, 3, s.count)
require.Equal(t, 1, s.bulkCount)
store.BulkSet([]SetRequest{{}, {}, {}, {}})
store.BulkSet(context.TODO(), []SetRequest{{}, {}, {}, {}})
require.Equal(t, 3, s.count)
require.Equal(t, 2, s.bulkCount)
store.BulkDelete([]DeleteRequest{{}, {}, {}, {}, {}})
store.BulkDelete(context.TODO(), []DeleteRequest{{}, {}, {}, {}, {}})
require.Equal(t, 3, s.count)
require.Equal(t, 3, s.bulkCount)
}
@ -110,19 +111,19 @@ func (s *Store1) Init(metadata Metadata) error {
return nil
}
func (s *Store1) Delete(req *DeleteRequest) error {
func (s *Store1) Delete(ctx context.Context, req *DeleteRequest) error {
s.count++
return nil
}
func (s *Store1) Get(req *GetRequest) (*GetResponse, error) {
func (s *Store1) Get(ctx context.Context, req *GetRequest) (*GetResponse, error) {
s.count++
return &GetResponse{}, nil
}
func (s *Store1) Set(req *SetRequest) error {
func (s *Store1) Set(ctx context.Context, req *SetRequest) error {
s.count++
return nil
@ -145,25 +146,25 @@ func (s *Store2) Features() []Feature {
return nil
}
func (s *Store2) Delete(req *DeleteRequest) error {
func (s *Store2) Delete(ctx context.Context, req *DeleteRequest) error {
s.count++
return nil
}
func (s *Store2) Get(req *GetRequest) (*GetResponse, error) {
func (s *Store2) Get(ctx context.Context, req *GetRequest) (*GetResponse, error) {
s.count++
return &GetResponse{}, nil
}
func (s *Store2) Set(req *SetRequest) error {
func (s *Store2) Set(ctx context.Context, req *SetRequest) error {
s.count++
return nil
}
func (s *Store2) BulkGet(req []GetRequest) (bool, []BulkGetResponse, error) {
func (s *Store2) BulkGet(ctx context.Context, req []GetRequest) (bool, []BulkGetResponse, error) {
if s.supportBulkGet {
s.bulkCount++
@ -175,13 +176,13 @@ func (s *Store2) BulkGet(req []GetRequest) (bool, []BulkGetResponse, error) {
return false, nil, nil
}
func (s *Store2) BulkSet(req []SetRequest) error {
func (s *Store2) BulkSet(ctx context.Context, req []SetRequest) error {
s.bulkCount++
return nil
}
func (s *Store2) BulkDelete(req []DeleteRequest) error {
func (s *Store2) BulkDelete(ctx context.Context, req []DeleteRequest) error {
s.bulkCount++
return nil

View File

@ -14,6 +14,7 @@ limitations under the License.
package zookeeper
import (
"context"
"errors"
"path"
"strconv"
@ -161,7 +162,7 @@ func (s *StateStore) Features() []state.Feature {
}
// Get retrieves state from Zookeeper with a key.
func (s *StateStore) Get(req *state.GetRequest) (*state.GetResponse, error) {
func (s *StateStore) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) {
value, stat, err := s.conn.Get(s.prefixedKey(req.Key))
if err != nil {
if errors.Is(err, zk.ErrNoNode) {
@ -178,19 +179,19 @@ func (s *StateStore) Get(req *state.GetRequest) (*state.GetResponse, error) {
}
// BulkGet performs a bulks get operations.
func (s *StateStore) BulkGet(req []state.GetRequest) (bool, []state.BulkGetResponse, error) {
func (s *StateStore) BulkGet(ctx context.Context, req []state.GetRequest) (bool, []state.BulkGetResponse, error) {
// TODO: replace with Multi for performance
return false, nil, nil
}
// Delete performs a delete operation.
func (s *StateStore) Delete(req *state.DeleteRequest) error {
func (s *StateStore) Delete(ctx context.Context, req *state.DeleteRequest) error {
r, err := s.newDeleteRequest(req)
if err != nil {
return err
}
return state.DeleteWithOptions(func(req *state.DeleteRequest) error {
return state.DeleteWithOptions(ctx, func(ctx context.Context, req *state.DeleteRequest) error {
err := s.conn.Delete(r.Path, r.Version)
if errors.Is(err, zk.ErrNoNode) {
return nil
@ -209,7 +210,7 @@ func (s *StateStore) Delete(req *state.DeleteRequest) error {
}
// BulkDelete performs a bulk delete operation.
func (s *StateStore) BulkDelete(reqs []state.DeleteRequest) error {
func (s *StateStore) BulkDelete(ctx context.Context, reqs []state.DeleteRequest) error {
ops := make([]interface{}, 0, len(reqs))
for i := range reqs {
@ -236,13 +237,13 @@ func (s *StateStore) BulkDelete(reqs []state.DeleteRequest) error {
}
// Set saves state into Zookeeper.
func (s *StateStore) Set(req *state.SetRequest) error {
func (s *StateStore) Set(ctx context.Context, req *state.SetRequest) error {
r, err := s.newSetDataRequest(req)
if err != nil {
return err
}
return state.SetWithOptions(func(req *state.SetRequest) error {
return state.SetWithOptions(ctx, func(ctx context.Context, req *state.SetRequest) error {
_, err = s.conn.Set(r.Path, r.Data, r.Version)
if errors.Is(err, zk.ErrNoNode) {
@ -262,7 +263,7 @@ func (s *StateStore) Set(req *state.SetRequest) error {
}
// BulkSet performs a bulks save operation.
func (s *StateStore) BulkSet(reqs []state.SetRequest) error {
func (s *StateStore) BulkSet(ctx context.Context, reqs []state.SetRequest) error {
ops := make([]interface{}, 0, len(reqs))
for i := range reqs {

View File

@ -14,6 +14,7 @@ limitations under the License.
package zookeeper
import (
"context"
"fmt"
"testing"
"time"
@ -80,7 +81,7 @@ func TestGet(t *testing.T) {
t.Run("With key exists", func(t *testing.T) {
conn.EXPECT().Get("foo").Return([]byte("bar"), &zk.Stat{Version: 123}, nil).Times(1)
res, err := s.Get(&state.GetRequest{Key: "foo"})
res, err := s.Get(context.TODO(), &state.GetRequest{Key: "foo"})
assert.NotNil(t, res, "Key must be exists")
assert.Equal(t, "bar", string(res.Data), "Value must be equals")
assert.Equal(t, ptr.String("123"), res.ETag, "ETag must be equals")
@ -90,7 +91,7 @@ func TestGet(t *testing.T) {
t.Run("With key non-exists", func(t *testing.T) {
conn.EXPECT().Get("foo").Return(nil, nil, zk.ErrNoNode).Times(1)
res, err := s.Get(&state.GetRequest{Key: "foo"})
res, err := s.Get(context.TODO(), &state.GetRequest{Key: "foo"})
assert.Equal(t, &state.GetResponse{}, res, "Response must be empty")
assert.NoError(t, err, "Non-existent key must not be treated as error")
})
@ -108,21 +109,21 @@ func TestDelete(t *testing.T) {
t.Run("With key", func(t *testing.T) {
conn.EXPECT().Delete("foo", int32(anyVersion)).Return(nil).Times(1)
err := s.Delete(&state.DeleteRequest{Key: "foo"})
err := s.Delete(context.TODO(), &state.DeleteRequest{Key: "foo"})
assert.NoError(t, err, "Key must be exists")
})
t.Run("With key and version", func(t *testing.T) {
conn.EXPECT().Delete("foo", int32(123)).Return(nil).Times(1)
err := s.Delete(&state.DeleteRequest{Key: "foo", ETag: &etag})
err := s.Delete(context.TODO(), &state.DeleteRequest{Key: "foo", ETag: &etag})
assert.NoError(t, err, "Key must be exists")
})
t.Run("With key and concurrency", func(t *testing.T) {
conn.EXPECT().Delete("foo", int32(anyVersion)).Return(nil).Times(1)
err := s.Delete(&state.DeleteRequest{
err := s.Delete(context.TODO(), &state.DeleteRequest{
Key: "foo",
ETag: &etag,
Options: state.DeleteStateOption{Concurrency: state.LastWrite},
@ -133,14 +134,14 @@ func TestDelete(t *testing.T) {
t.Run("With delete error", func(t *testing.T) {
conn.EXPECT().Delete("foo", int32(anyVersion)).Return(zk.ErrUnknown).Times(1)
err := s.Delete(&state.DeleteRequest{Key: "foo"})
err := s.Delete(context.TODO(), &state.DeleteRequest{Key: "foo"})
assert.EqualError(t, err, "zk: unknown error")
})
t.Run("With delete and ignore NoNode error", func(t *testing.T) {
conn.EXPECT().Delete("foo", int32(anyVersion)).Return(zk.ErrNoNode).Times(1)
err := s.Delete(&state.DeleteRequest{Key: "foo"})
err := s.Delete(context.TODO(), &state.DeleteRequest{Key: "foo"})
assert.NoError(t, err, "Delete must be successful")
})
}
@ -159,7 +160,7 @@ func TestBulkDelete(t *testing.T) {
&zk.DeleteRequest{Path: "bar", Version: int32(anyVersion)},
}).Return([]zk.MultiResponse{{}, {}}, nil).Times(1)
err := s.BulkDelete([]state.DeleteRequest{{Key: "foo"}, {Key: "bar"}})
err := s.BulkDelete(context.TODO(), []state.DeleteRequest{{Key: "foo"}, {Key: "bar"}})
assert.NoError(t, err, "Key must be exists")
})
@ -171,7 +172,7 @@ func TestBulkDelete(t *testing.T) {
{Error: zk.ErrUnknown}, {Error: zk.ErrNoAuth},
}, nil).Times(1)
err := s.BulkDelete([]state.DeleteRequest{{Key: "foo"}, {Key: "bar"}})
err := s.BulkDelete(context.TODO(), []state.DeleteRequest{{Key: "foo"}, {Key: "bar"}})
assert.Equal(t, err.(*multierror.Error).Errors, []error{zk.ErrUnknown, zk.ErrNoAuth})
})
t.Run("With keys and ignore NoNode error", func(t *testing.T) {
@ -182,7 +183,7 @@ func TestBulkDelete(t *testing.T) {
{Error: zk.ErrNoNode}, {},
}, nil).Times(1)
err := s.BulkDelete([]state.DeleteRequest{{Key: "foo"}, {Key: "bar"}})
err := s.BulkDelete(context.TODO(), []state.DeleteRequest{{Key: "foo"}, {Key: "bar"}})
assert.NoError(t, err, "Key must be exists")
})
}
@ -201,19 +202,19 @@ func TestSet(t *testing.T) {
t.Run("With key", func(t *testing.T) {
conn.EXPECT().Set("foo", []byte("\"bar\""), int32(anyVersion)).Return(stat, nil).Times(1)
err := s.Set(&state.SetRequest{Key: "foo", Value: "bar"})
err := s.Set(context.TODO(), &state.SetRequest{Key: "foo", Value: "bar"})
assert.NoError(t, err, "Key must be set")
})
t.Run("With key and version", func(t *testing.T) {
conn.EXPECT().Set("foo", []byte("\"bar\""), int32(123)).Return(stat, nil).Times(1)
err := s.Set(&state.SetRequest{Key: "foo", Value: "bar", ETag: &etag})
err := s.Set(context.TODO(), &state.SetRequest{Key: "foo", Value: "bar", ETag: &etag})
assert.NoError(t, err, "Key must be set")
})
t.Run("With key and concurrency", func(t *testing.T) {
conn.EXPECT().Set("foo", []byte("\"bar\""), int32(anyVersion)).Return(stat, nil).Times(1)
err := s.Set(&state.SetRequest{
err := s.Set(context.TODO(), &state.SetRequest{
Key: "foo",
Value: "bar",
ETag: &etag,
@ -225,14 +226,14 @@ func TestSet(t *testing.T) {
t.Run("With error", func(t *testing.T) {
conn.EXPECT().Set("foo", []byte("\"bar\""), int32(anyVersion)).Return(nil, zk.ErrUnknown).Times(1)
err := s.Set(&state.SetRequest{Key: "foo", Value: "bar"})
err := s.Set(context.TODO(), &state.SetRequest{Key: "foo", Value: "bar"})
assert.EqualError(t, err, "zk: unknown error")
})
t.Run("With NoNode error and retry", func(t *testing.T) {
conn.EXPECT().Set("foo", []byte("\"bar\""), int32(anyVersion)).Return(nil, zk.ErrNoNode).Times(1)
conn.EXPECT().Create("foo", []byte("\"bar\""), int32(0), nil).Return("/foo", nil).Times(1)
err := s.Set(&state.SetRequest{Key: "foo", Value: "bar"})
err := s.Set(context.TODO(), &state.SetRequest{Key: "foo", Value: "bar"})
assert.NoError(t, err, "Key must be create")
})
}
@ -251,7 +252,7 @@ func TestBulkSet(t *testing.T) {
&zk.SetDataRequest{Path: "bar", Data: []byte("\"foo\""), Version: int32(anyVersion)},
}).Return([]zk.MultiResponse{{}, {}}, nil).Times(1)
err := s.BulkSet([]state.SetRequest{
err := s.BulkSet(context.TODO(), []state.SetRequest{
{Key: "foo", Value: "bar"},
{Key: "bar", Value: "foo"},
})
@ -266,7 +267,7 @@ func TestBulkSet(t *testing.T) {
{Error: zk.ErrUnknown}, {Error: zk.ErrNoAuth},
}, nil).Times(1)
err := s.BulkSet([]state.SetRequest{
err := s.BulkSet(context.TODO(), []state.SetRequest{
{Key: "foo", Value: "bar"},
{Key: "bar", Value: "foo"},
})
@ -283,7 +284,7 @@ func TestBulkSet(t *testing.T) {
&zk.CreateRequest{Path: "foo", Data: []byte("\"bar\"")},
}).Return([]zk.MultiResponse{{}, {}}, nil).Times(1)
err := s.BulkSet([]state.SetRequest{
err := s.BulkSet(context.TODO(), []state.SetRequest{
{Key: "foo", Value: "bar"},
{Key: "bar", Value: "foo"},
})

View File

@ -14,6 +14,7 @@ limitations under the License.
package state
import (
"context"
"encoding/json"
"fmt"
"sort"
@ -251,7 +252,7 @@ func ConformanceTests(t *testing.T, props map[string]string, statestore state.St
if len(scenario.contentType) != 0 {
req.Metadata = map[string]string{metadata.ContentType: scenario.contentType}
}
err := statestore.Set(req)
err := statestore.Set(context.TODO(), req)
assert.Nil(t, err)
}
}
@ -269,7 +270,7 @@ func ConformanceTests(t *testing.T, props map[string]string, statestore state.St
if len(scenario.contentType) != 0 {
req.Metadata = map[string]string{metadata.ContentType: scenario.contentType}
}
res, err := statestore.Get(req)
res, err := statestore.Get(context.TODO(), req)
assert.Nil(t, err)
assertEquals(t, scenario.value, res)
}
@ -290,7 +291,7 @@ func ConformanceTests(t *testing.T, props map[string]string, statestore state.St
metadata.ContentType: contenttype.JSONContentType,
metadata.QueryIndexName: "qIndx",
}
resp, err := querier.Query(&req)
resp, err := querier.Query(context.TODO(), &req)
assert.NoError(t, err)
assert.Equal(t, len(scenario.results), len(resp.Results))
for i := range scenario.results {
@ -318,11 +319,11 @@ func ConformanceTests(t *testing.T, props map[string]string, statestore state.St
if len(scenario.contentType) != 0 {
req.Metadata = map[string]string{metadata.ContentType: scenario.contentType}
}
err := statestore.Delete(req)
err := statestore.Delete(context.TODO(), req)
assert.Nil(t, err, "no error expected while deleting %s", scenario.key)
t.Logf("Checking value absence for %s", scenario.key)
res, err := statestore.Get(&state.GetRequest{
res, err := statestore.Get(context.TODO(), &state.GetRequest{
Key: scenario.key,
})
assert.Nil(t, err, "no error expected while checking for absence for %s", scenario.key)
@ -344,14 +345,14 @@ func ConformanceTests(t *testing.T, props map[string]string, statestore state.St
})
}
}
err := statestore.BulkSet(bulk)
err := statestore.BulkSet(context.TODO(), bulk)
assert.Nil(t, err)
for _, scenario := range scenarios {
if scenario.bulkOnly {
t.Logf("Checking value presence for %s", scenario.key)
// Data should have been inserted at this point
res, err := statestore.Get(&state.GetRequest{
res, err := statestore.Get(context.TODO(), &state.GetRequest{
Key: scenario.key,
})
assert.Nil(t, err)
@ -372,12 +373,12 @@ func ConformanceTests(t *testing.T, props map[string]string, statestore state.St
})
}
}
err := statestore.BulkDelete(bulk)
err := statestore.BulkDelete(context.TODO(), bulk)
assert.Nil(t, err)
for _, req := range bulk {
t.Logf("Checking value absence for %s", req.Key)
res, err := statestore.Get(&state.GetRequest{
res, err := statestore.Get(context.TODO(), &state.GetRequest{
Key: req.Key,
})
assert.Nil(t, err)
@ -443,7 +444,7 @@ func ConformanceTests(t *testing.T, props map[string]string, statestore state.St
if scenario.transactionGroup == transactionGroup {
t.Logf("Checking value presence for %s", scenario.key)
// Data should have been inserted at this point
res, err := statestore.Get(&state.GetRequest{
res, err := statestore.Get(context.TODO(), &state.GetRequest{
Key: scenario.key,
// For CosmosDB
Metadata: map[string]string{
@ -457,7 +458,7 @@ func ConformanceTests(t *testing.T, props map[string]string, statestore state.St
if scenario.toBeDeleted && (scenario.transactionGroup == transactionGroup-1) {
t.Logf("Checking value absence for %s", scenario.key)
// Data should have been deleted at this point
res, err := statestore.Get(&state.GetRequest{
res, err := statestore.Get(context.TODO(), &state.GetRequest{
Key: scenario.key,
// For CosmosDB
Metadata: map[string]string{
@ -487,7 +488,7 @@ func ConformanceTests(t *testing.T, props map[string]string, statestore state.St
}
// prerequisite: key1 should be present
err := statestore.Set(&state.SetRequest{
err := statestore.Set(context.TODO(), &state.SetRequest{
Key: firstKey,
Value: firstValue,
Metadata: partitionMetadata,
@ -495,14 +496,14 @@ func ConformanceTests(t *testing.T, props map[string]string, statestore state.St
assert.NoError(t, err, "set request should be successful")
// prerequisite: key2 should not be present
err = statestore.Delete(&state.DeleteRequest{
err = statestore.Delete(context.TODO(), &state.DeleteRequest{
Key: secondKey,
Metadata: partitionMetadata,
})
assert.NoError(t, err, "delete request should be successful")
// prerequisite: key3 should not be present
err = statestore.Delete(&state.DeleteRequest{
err = statestore.Delete(context.TODO(), &state.DeleteRequest{
Key: thirdKey,
Metadata: partitionMetadata,
})
@ -558,7 +559,7 @@ func ConformanceTests(t *testing.T, props map[string]string, statestore state.St
// Assert
for k, v := range expected {
res, err := statestore.Get(&state.GetRequest{
res, err := statestore.Get(context.TODO(), &state.GetRequest{
Key: k,
Metadata: partitionMetadata,
})
@ -585,20 +586,20 @@ func ConformanceTests(t *testing.T, props map[string]string, statestore state.St
require.True(t, state.FeatureETag.IsPresent(features))
// Delete any potential object, it's important to start from a clean slate.
err := statestore.Delete(&state.DeleteRequest{
err := statestore.Delete(context.TODO(), &state.DeleteRequest{
Key: testKey,
})
require.Nil(t, err)
// Set an object.
err = statestore.Set(&state.SetRequest{
err = statestore.Set(context.TODO(), &state.SetRequest{
Key: testKey,
Value: firstValue,
})
require.Nil(t, err)
// Validate the set.
res, err := statestore.Get(&state.GetRequest{
res, err := statestore.Get(context.TODO(), &state.GetRequest{
Key: testKey,
})
@ -607,7 +608,7 @@ func ConformanceTests(t *testing.T, props map[string]string, statestore state.St
etag := res.ETag
// Try and update with wrong ETag, expect failure.
err = statestore.Set(&state.SetRequest{
err = statestore.Set(context.TODO(), &state.SetRequest{
Key: testKey,
Value: secondValue,
ETag: &fakeEtag,
@ -615,7 +616,7 @@ func ConformanceTests(t *testing.T, props map[string]string, statestore state.St
require.NotNil(t, err)
// Try and update with corect ETag, expect success.
err = statestore.Set(&state.SetRequest{
err = statestore.Set(context.TODO(), &state.SetRequest{
Key: testKey,
Value: secondValue,
ETag: etag,
@ -623,7 +624,7 @@ func ConformanceTests(t *testing.T, props map[string]string, statestore state.St
require.Nil(t, err)
// Validate the set.
res, err = statestore.Get(&state.GetRequest{
res, err = statestore.Get(context.TODO(), &state.GetRequest{
Key: testKey,
})
require.Nil(t, err)
@ -632,14 +633,14 @@ func ConformanceTests(t *testing.T, props map[string]string, statestore state.St
etag = res.ETag
// Try and delete with wrong ETag, expect failure.
err = statestore.Delete(&state.DeleteRequest{
err = statestore.Delete(context.TODO(), &state.DeleteRequest{
Key: testKey,
ETag: &fakeEtag,
})
require.NotNil(t, err)
// Try and delete with correct ETag, expect success.
err = statestore.Delete(&state.DeleteRequest{
err = statestore.Delete(context.TODO(), &state.DeleteRequest{
Key: testKey,
ETag: etag,
})
@ -698,23 +699,23 @@ func ConformanceTests(t *testing.T, props map[string]string, statestore state.St
for i, requestSet := range requestSets {
t.Run(fmt.Sprintf("request set %d", i), func(t *testing.T) {
// Delete any potential object, it's important to start from a clean slate.
err := statestore.Delete(&state.DeleteRequest{
err := statestore.Delete(context.TODO(), &state.DeleteRequest{
Key: testKey,
})
require.Nil(t, err)
err = statestore.Set(requestSet[0])
err = statestore.Set(context.TODO(), requestSet[0])
require.Nil(t, err)
// Validate the set.
res, err := statestore.Get(&state.GetRequest{
res, err := statestore.Get(context.TODO(), &state.GetRequest{
Key: testKey,
})
require.Nil(t, err)
assertEquals(t, firstValue, res)
// Second write expect fail
err = statestore.Set(requestSet[1])
err = statestore.Set(context.TODO(), requestSet[1])
require.NotNil(t, err)
})
}
@ -731,16 +732,16 @@ func ConformanceTests(t *testing.T, props map[string]string, statestore state.St
}
// Delete any potential object, it's important to start from a clean slate.
err := statestore.Delete(&state.DeleteRequest{
err := statestore.Delete(context.TODO(), &state.DeleteRequest{
Key: testKey,
})
require.Nil(t, err)
err = statestore.Set(request)
err = statestore.Set(context.TODO(), request)
require.Nil(t, err)
// Validate the set.
res, err := statestore.Get(&state.GetRequest{
res, err := statestore.Get(context.TODO(), &state.GetRequest{
Key: testKey,
})
require.Nil(t, err)
@ -757,11 +758,11 @@ func ConformanceTests(t *testing.T, props map[string]string, statestore state.St
Consistency: state.Strong,
},
}
err = statestore.Set(request)
err = statestore.Set(context.TODO(), request)
require.Nil(t, err)
// Validate the set.
res, err = statestore.Get(&state.GetRequest{
res, err = statestore.Get(context.TODO(), &state.GetRequest{
Key: testKey,
})
require.Nil(t, err)
@ -771,7 +772,7 @@ func ConformanceTests(t *testing.T, props map[string]string, statestore state.St
request.ETag = etag
// Second write expect fail
err = statestore.Set(request)
err = statestore.Set(context.TODO(), request)
require.NotNil(t, err)
})
}