feature: add context to state API

Signed-off-by: 1046102779 <seachen@tencent.com>
This commit is contained in:
1046102779 2022-10-21 17:25:14 +08:00
parent 0326f139c5
commit 5a367b401a
58 changed files with 846 additions and 822 deletions

View File

@ -14,6 +14,7 @@ limitations under the License.
package aerospike package aerospike
import ( import (
"context"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
@ -110,7 +111,7 @@ func (aspike *Aerospike) Features() []state.Feature {
} }
// Set stores value for a key to Aerospike. It honors ETag (for concurrency) and consistency settings. // Set stores value for a key to Aerospike. It honors ETag (for concurrency) and consistency settings.
func (aspike *Aerospike) Set(req *state.SetRequest) error { func (aspike *Aerospike) Set(ctx context.Context, req *state.SetRequest) error {
err := state.CheckRequestOptions(req.Options) err := state.CheckRequestOptions(req.Options)
if err != nil { if err != nil {
return err return err
@ -162,7 +163,7 @@ func (aspike *Aerospike) Set(req *state.SetRequest) error {
} }
// Get retrieves state from Aerospike with a key. // Get retrieves state from Aerospike with a key.
func (aspike *Aerospike) Get(req *state.GetRequest) (*state.GetResponse, error) { func (aspike *Aerospike) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) {
asKey, err := as.NewKey(aspike.namespace, aspike.set, req.Key) asKey, err := as.NewKey(aspike.namespace, aspike.set, req.Key)
if err != nil { if err != nil {
return nil, err return nil, err
@ -196,7 +197,7 @@ func (aspike *Aerospike) Get(req *state.GetRequest) (*state.GetResponse, error)
} }
// Delete performs a delete operation. // Delete performs a delete operation.
func (aspike *Aerospike) Delete(req *state.DeleteRequest) error { func (aspike *Aerospike) Delete(ctx context.Context, req *state.DeleteRequest) error {
err := state.CheckRequestOptions(req.Options) err := state.CheckRequestOptions(req.Options)
if err != nil { if err != nil {
return err return err

View File

@ -14,6 +14,7 @@ limitations under the License.
package tablestore package tablestore
import ( import (
"context"
"encoding/json" "encoding/json"
"github.com/agrea/ptr" "github.com/agrea/ptr"
@ -68,7 +69,7 @@ func (s *AliCloudTableStore) Features() []state.Feature {
return s.features return s.features
} }
func (s *AliCloudTableStore) Get(req *state.GetRequest) (*state.GetResponse, error) { func (s *AliCloudTableStore) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) {
criteria := &tablestore.SingleRowQueryCriteria{ criteria := &tablestore.SingleRowQueryCriteria{
PrimaryKey: s.primaryKey(req.Key), PrimaryKey: s.primaryKey(req.Key),
TableName: s.metadata.TableName, TableName: s.metadata.TableName,
@ -103,7 +104,7 @@ func (s *AliCloudTableStore) getResp(columns []*tablestore.AttributeColumn) *sta
return getResp return getResp
} }
func (s *AliCloudTableStore) BulkGet(reqs []state.GetRequest) (bool, []state.BulkGetResponse, error) { func (s *AliCloudTableStore) BulkGet(ctx context.Context, reqs []state.GetRequest) (bool, []state.BulkGetResponse, error) {
// "len == 0": empty request, directly return empty response // "len == 0": empty request, directly return empty response
if len(reqs) == 0 { if len(reqs) == 0 {
return true, []state.BulkGetResponse{}, nil return true, []state.BulkGetResponse{}, nil
@ -139,7 +140,7 @@ func (s *AliCloudTableStore) BulkGet(reqs []state.GetRequest) (bool, []state.Bul
return true, responseList, nil return true, responseList, nil
} }
func (s *AliCloudTableStore) Set(req *state.SetRequest) error { func (s *AliCloudTableStore) Set(ctx context.Context, req *state.SetRequest) error {
change := s.updateRowChange(req) change := s.updateRowChange(req)
request := &tablestore.UpdateRowRequest{ request := &tablestore.UpdateRowRequest{
@ -183,7 +184,7 @@ func unmarshal(val interface{}) []byte {
return []byte(output) return []byte(output)
} }
func (s *AliCloudTableStore) Delete(req *state.DeleteRequest) error { func (s *AliCloudTableStore) Delete(ctx context.Context, req *state.DeleteRequest) error {
change := s.deleteRowChange(req) change := s.deleteRowChange(req)
deleteRowReq := &tablestore.DeleteRowRequest{ deleteRowReq := &tablestore.DeleteRowRequest{
@ -205,15 +206,15 @@ func (s *AliCloudTableStore) deleteRowChange(req *state.DeleteRequest) *tablesto
return change return change
} }
func (s *AliCloudTableStore) BulkSet(reqs []state.SetRequest) error { func (s *AliCloudTableStore) BulkSet(ctx context.Context, reqs []state.SetRequest) error {
return s.batchWrite(reqs, nil) return s.batchWrite(ctx, reqs, nil)
} }
func (s *AliCloudTableStore) BulkDelete(reqs []state.DeleteRequest) error { func (s *AliCloudTableStore) BulkDelete(ctx context.Context, reqs []state.DeleteRequest) error {
return s.batchWrite(nil, reqs) return s.batchWrite(ctx, nil, reqs)
} }
func (s *AliCloudTableStore) batchWrite(setReqs []state.SetRequest, deleteReqs []state.DeleteRequest) error { func (s *AliCloudTableStore) batchWrite(ctx context.Context, setReqs []state.SetRequest, deleteReqs []state.DeleteRequest) error {
bathReq := &tablestore.BatchWriteRowRequest{ bathReq := &tablestore.BatchWriteRowRequest{
IsAtomic: true, IsAtomic: true,
} }

View File

@ -14,6 +14,7 @@ limitations under the License.
package tablestore package tablestore
import ( import (
"context"
"testing" "testing"
"github.com/agrea/ptr" "github.com/agrea/ptr"
@ -63,7 +64,7 @@ func TestReadAndWrite(t *testing.T) {
Value: "value of key", Value: "value of key",
ETag: ptr.String("the etag"), ETag: ptr.String("the etag"),
} }
err := store.Set(setReq) err := store.Set(context.TODO(), setReq)
assert.Nil(t, err) assert.Nil(t, err)
}) })
@ -71,7 +72,7 @@ func TestReadAndWrite(t *testing.T) {
getReq := &state.GetRequest{ getReq := &state.GetRequest{
Key: "theFirstKey", Key: "theFirstKey",
} }
resp, err := store.Get(getReq) resp, err := store.Get(context.TODO(), getReq)
assert.Nil(t, err) assert.Nil(t, err)
assert.NotNil(t, resp) assert.NotNil(t, resp)
assert.Equal(t, "value of key", string(resp.Data)) assert.Equal(t, "value of key", string(resp.Data))
@ -83,7 +84,7 @@ func TestReadAndWrite(t *testing.T) {
Value: "1234", Value: "1234",
ETag: ptr.String("the etag"), ETag: ptr.String("the etag"),
} }
err := store.Set(setReq) err := store.Set(context.TODO(), setReq)
assert.Nil(t, err) assert.Nil(t, err)
}) })
@ -91,14 +92,14 @@ func TestReadAndWrite(t *testing.T) {
getReq := &state.GetRequest{ getReq := &state.GetRequest{
Key: "theSecondKey", Key: "theSecondKey",
} }
resp, err := store.Get(getReq) resp, err := store.Get(context.TODO(), getReq)
assert.Nil(t, err) assert.Nil(t, err)
assert.NotNil(t, resp) assert.NotNil(t, resp)
assert.Equal(t, "1234", string(resp.Data)) assert.Equal(t, "1234", string(resp.Data))
}) })
t.Run("test BulkSet", func(t *testing.T) { t.Run("test BulkSet", func(t *testing.T) {
err := store.BulkSet([]state.SetRequest{{ err := store.BulkSet(context.TODO(), []state.SetRequest{{
Key: "theFirstKey", Key: "theFirstKey",
Value: "666", Value: "666",
}, { }, {
@ -110,7 +111,7 @@ func TestReadAndWrite(t *testing.T) {
}) })
t.Run("test BulkGet", func(t *testing.T) { t.Run("test BulkGet", func(t *testing.T) {
_, resp, err := store.BulkGet([]state.GetRequest{{ _, resp, err := store.BulkGet(context.TODO(), []state.GetRequest{{
Key: "theFirstKey", Key: "theFirstKey",
}, { }, {
Key: "theSecondKey", Key: "theSecondKey",
@ -126,12 +127,12 @@ func TestReadAndWrite(t *testing.T) {
req := &state.DeleteRequest{ req := &state.DeleteRequest{
Key: "theFirstKey", Key: "theFirstKey",
} }
err := store.Delete(req) err := store.Delete(context.TODO(), req)
assert.Nil(t, err) assert.Nil(t, err)
}) })
t.Run("test BulkGet2", func(t *testing.T) { t.Run("test BulkGet2", func(t *testing.T) {
_, resp, err := store.BulkGet([]state.GetRequest{{ _, resp, err := store.BulkGet(context.TODO(), []state.GetRequest{{
Key: "theFirstKey", Key: "theFirstKey",
}, { }, {
Key: "theSecondKey", Key: "theSecondKey",

View File

@ -14,6 +14,7 @@ limitations under the License.
package dynamodb package dynamodb
import ( import (
"context"
"crypto/rand" "crypto/rand"
"encoding/binary" "encoding/binary"
"encoding/json" "encoding/json"
@ -79,7 +80,7 @@ func (d *StateStore) Features() []state.Feature {
} }
// Get retrieves a dynamoDB item. // Get retrieves a dynamoDB item.
func (d *StateStore) Get(req *state.GetRequest) (*state.GetResponse, error) { func (d *StateStore) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) {
input := &dynamodb.GetItemInput{ input := &dynamodb.GetItemInput{
ConsistentRead: aws.Bool(req.Options.Consistency == state.Strong), ConsistentRead: aws.Bool(req.Options.Consistency == state.Strong),
TableName: aws.String(d.table), TableName: aws.String(d.table),
@ -90,7 +91,7 @@ func (d *StateStore) Get(req *state.GetRequest) (*state.GetResponse, error) {
}, },
} }
result, err := d.client.GetItem(input) result, err := d.client.GetItemWithContext(ctx, input)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -133,13 +134,13 @@ func (d *StateStore) Get(req *state.GetRequest) (*state.GetResponse, error) {
} }
// BulkGet performs a bulk get operations. // BulkGet performs a bulk get operations.
func (d *StateStore) BulkGet(req []state.GetRequest) (bool, []state.BulkGetResponse, error) { func (d *StateStore) BulkGet(ctx context.Context, req []state.GetRequest) (bool, []state.BulkGetResponse, error) {
// TODO: replace with dynamodb.BatchGetItem for performance // TODO: replace with dynamodb.BatchGetItem for performance
return false, nil, nil return false, nil, nil
} }
// Set saves a dynamoDB item. // Set saves a dynamoDB item.
func (d *StateStore) Set(req *state.SetRequest) error { func (d *StateStore) Set(ctx context.Context, req *state.SetRequest) error {
item, err := d.getItemFromReq(req) item, err := d.getItemFromReq(req)
if err != nil { if err != nil {
return err return err
@ -165,7 +166,7 @@ func (d *StateStore) Set(req *state.SetRequest) error {
input.ConditionExpression = &condExpr input.ConditionExpression = &condExpr
} }
_, err = d.client.PutItem(input) _, err = d.client.PutItemWithContext(ctx, input)
if err != nil && haveEtag { if err != nil && haveEtag {
switch cErr := err.(type) { switch cErr := err.(type) {
case *dynamodb.ConditionalCheckFailedException: case *dynamodb.ConditionalCheckFailedException:
@ -177,11 +178,11 @@ func (d *StateStore) Set(req *state.SetRequest) error {
} }
// BulkSet performs a bulk set operation. // BulkSet performs a bulk set operation.
func (d *StateStore) BulkSet(req []state.SetRequest) error { func (d *StateStore) BulkSet(ctx context.Context, req []state.SetRequest) error {
writeRequests := []*dynamodb.WriteRequest{} writeRequests := []*dynamodb.WriteRequest{}
if len(req) == 1 { if len(req) == 1 {
return d.Set(&req[0]) return d.Set(ctx, &req[0])
} }
for _, r := range req { for _, r := range req {
@ -210,7 +211,7 @@ func (d *StateStore) BulkSet(req []state.SetRequest) error {
requestItems := map[string][]*dynamodb.WriteRequest{} requestItems := map[string][]*dynamodb.WriteRequest{}
requestItems[d.table] = writeRequests requestItems[d.table] = writeRequests
_, e := d.client.BatchWriteItem(&dynamodb.BatchWriteItemInput{ _, e := d.client.BatchWriteItemWithContext(ctx, &dynamodb.BatchWriteItemInput{
RequestItems: requestItems, RequestItems: requestItems,
}) })
@ -218,7 +219,7 @@ func (d *StateStore) BulkSet(req []state.SetRequest) error {
} }
// Delete performs a delete operation. // Delete performs a delete operation.
func (d *StateStore) Delete(req *state.DeleteRequest) error { func (d *StateStore) Delete(ctx context.Context, req *state.DeleteRequest) error {
input := &dynamodb.DeleteItemInput{ input := &dynamodb.DeleteItemInput{
Key: map[string]*dynamodb.AttributeValue{ Key: map[string]*dynamodb.AttributeValue{
"key": { "key": {
@ -238,7 +239,7 @@ func (d *StateStore) Delete(req *state.DeleteRequest) error {
input.ExpressionAttributeValues = exprAttrValues input.ExpressionAttributeValues = exprAttrValues
} }
_, err := d.client.DeleteItem(input) _, err := d.client.DeleteItemWithContext(ctx, input)
if err != nil { if err != nil {
switch cErr := err.(type) { switch cErr := err.(type) {
case *dynamodb.ConditionalCheckFailedException: case *dynamodb.ConditionalCheckFailedException:
@ -250,11 +251,11 @@ func (d *StateStore) Delete(req *state.DeleteRequest) error {
} }
// BulkDelete performs a bulk delete operation. // BulkDelete performs a bulk delete operation.
func (d *StateStore) BulkDelete(req []state.DeleteRequest) error { func (d *StateStore) BulkDelete(ctx context.Context, req []state.DeleteRequest) error {
writeRequests := []*dynamodb.WriteRequest{} writeRequests := []*dynamodb.WriteRequest{}
if len(req) == 1 { if len(req) == 1 {
return d.Delete(&req[0]) return d.Delete(ctx, &req[0])
} }
for _, r := range req { for _, r := range req {
@ -277,7 +278,7 @@ func (d *StateStore) BulkDelete(req []state.DeleteRequest) error {
requestItems := map[string][]*dynamodb.WriteRequest{} requestItems := map[string][]*dynamodb.WriteRequest{}
requestItems[d.table] = writeRequests requestItems[d.table] = writeRequests
_, e := d.client.BatchWriteItem(&dynamodb.BatchWriteItemInput{ _, e := d.client.BatchWriteItemWithContext(ctx, &dynamodb.BatchWriteItemInput{
RequestItems: requestItems, RequestItems: requestItems,
}) })

View File

@ -15,12 +15,14 @@ limitations under the License.
package dynamodb package dynamodb
import ( import (
"context"
"fmt" "fmt"
"strconv" "strconv"
"testing" "testing"
"time" "time"
"github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/service/dynamodb" "github.com/aws/aws-sdk-go/service/dynamodb"
"github.com/aws/aws-sdk-go/service/dynamodb/dynamodbattribute" "github.com/aws/aws-sdk-go/service/dynamodb/dynamodbattribute"
"github.com/aws/aws-sdk-go/service/dynamodb/dynamodbiface" "github.com/aws/aws-sdk-go/service/dynamodb/dynamodbiface"
@ -31,10 +33,10 @@ import (
) )
type mockedDynamoDB struct { type mockedDynamoDB struct {
GetItemFn func(input *dynamodb.GetItemInput) (*dynamodb.GetItemOutput, error) GetItemWithContextFn func(ctx context.Context, input *dynamodb.GetItemInput, op ...request.Option) (*dynamodb.GetItemOutput, error)
PutItemFn func(input *dynamodb.PutItemInput) (*dynamodb.PutItemOutput, error) PutItemWithContextFn func(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (*dynamodb.PutItemOutput, error)
DeleteItemFn func(input *dynamodb.DeleteItemInput) (*dynamodb.DeleteItemOutput, error) DeleteItemWithContextFn func(ctx context.Context, input *dynamodb.DeleteItemInput, op ...request.Option) (*dynamodb.DeleteItemOutput, error)
BatchWriteItemFn func(input *dynamodb.BatchWriteItemInput) (*dynamodb.BatchWriteItemOutput, error) BatchWriteItemWithContextFn func(ctx context.Context, input *dynamodb.BatchWriteItemInput, op ...request.Option) (*dynamodb.BatchWriteItemOutput, error)
dynamodbiface.DynamoDBAPI dynamodbiface.DynamoDBAPI
} }
@ -44,20 +46,20 @@ type DynamoDBItem struct {
TestAttributeName int64 `json:"testAttributeName"` TestAttributeName int64 `json:"testAttributeName"`
} }
func (m *mockedDynamoDB) GetItem(input *dynamodb.GetItemInput) (*dynamodb.GetItemOutput, error) { func (m *mockedDynamoDB) GetItemWithContext(ctx context.Context, input *dynamodb.GetItemInput, op ...request.Option) (*dynamodb.GetItemOutput, error) {
return m.GetItemFn(input) return m.GetItemWithContextFn(ctx, input, op...)
} }
func (m *mockedDynamoDB) PutItem(input *dynamodb.PutItemInput) (*dynamodb.PutItemOutput, error) { func (m *mockedDynamoDB) PutItemWithContext(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (*dynamodb.PutItemOutput, error) {
return m.PutItemFn(input) return m.PutItemWithContextFn(ctx, input, op...)
} }
func (m *mockedDynamoDB) DeleteItem(input *dynamodb.DeleteItemInput) (*dynamodb.DeleteItemOutput, error) { func (m *mockedDynamoDB) DeleteItemWithContext(ctx context.Context, input *dynamodb.DeleteItemInput, op ...request.Option) (*dynamodb.DeleteItemOutput, error) {
return m.DeleteItemFn(input) return m.DeleteItemWithContextFn(ctx, input, op...)
} }
func (m *mockedDynamoDB) BatchWriteItem(input *dynamodb.BatchWriteItemInput) (*dynamodb.BatchWriteItemOutput, error) { func (m *mockedDynamoDB) BatchWriteItemWithContext(ctx context.Context, input *dynamodb.BatchWriteItemInput, op ...request.Option) (*dynamodb.BatchWriteItemOutput, error) {
return m.BatchWriteItemFn(input) return m.BatchWriteItemWithContextFn(ctx, input, op...)
} }
func TestInit(t *testing.T) { func TestInit(t *testing.T) {
@ -99,7 +101,7 @@ func TestGet(t *testing.T) {
t.Run("Successfully retrieve item", func(t *testing.T) { t.Run("Successfully retrieve item", func(t *testing.T) {
ss := StateStore{ ss := StateStore{
client: &mockedDynamoDB{ client: &mockedDynamoDB{
GetItemFn: func(input *dynamodb.GetItemInput) (output *dynamodb.GetItemOutput, err error) { GetItemWithContextFn: func(ctx context.Context, input *dynamodb.GetItemInput, op ...request.Option) (output *dynamodb.GetItemOutput, err error) {
return &dynamodb.GetItemOutput{ return &dynamodb.GetItemOutput{
Item: map[string]*dynamodb.AttributeValue{ Item: map[string]*dynamodb.AttributeValue{
"key": { "key": {
@ -123,7 +125,7 @@ func TestGet(t *testing.T) {
Consistency: "strong", Consistency: "strong",
}, },
} }
out, err := ss.Get(req) out, err := ss.Get(context.TODO(), req)
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, []byte("some value"), out.Data) assert.Equal(t, []byte("some value"), out.Data)
assert.Equal(t, "1bdead4badc0ffee", *out.ETag) assert.Equal(t, "1bdead4badc0ffee", *out.ETag)
@ -131,7 +133,7 @@ func TestGet(t *testing.T) {
t.Run("Successfully retrieve item (with unexpired ttl)", func(t *testing.T) { t.Run("Successfully retrieve item (with unexpired ttl)", func(t *testing.T) {
ss := StateStore{ ss := StateStore{
client: &mockedDynamoDB{ client: &mockedDynamoDB{
GetItemFn: func(input *dynamodb.GetItemInput) (output *dynamodb.GetItemOutput, err error) { GetItemWithContextFn: func(ctx context.Context, input *dynamodb.GetItemInput, op ...request.Option) (output *dynamodb.GetItemOutput, err error) {
return &dynamodb.GetItemOutput{ return &dynamodb.GetItemOutput{
Item: map[string]*dynamodb.AttributeValue{ Item: map[string]*dynamodb.AttributeValue{
"key": { "key": {
@ -159,7 +161,7 @@ func TestGet(t *testing.T) {
Consistency: "strong", Consistency: "strong",
}, },
} }
out, err := ss.Get(req) out, err := ss.Get(context.TODO(), req)
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, []byte("some value"), out.Data) assert.Equal(t, []byte("some value"), out.Data)
assert.Equal(t, "1bdead4badc0ffee", *out.ETag) assert.Equal(t, "1bdead4badc0ffee", *out.ETag)
@ -167,7 +169,7 @@ func TestGet(t *testing.T) {
t.Run("Successfully retrieve item (with expired ttl)", func(t *testing.T) { t.Run("Successfully retrieve item (with expired ttl)", func(t *testing.T) {
ss := StateStore{ ss := StateStore{
client: &mockedDynamoDB{ client: &mockedDynamoDB{
GetItemFn: func(input *dynamodb.GetItemInput) (output *dynamodb.GetItemOutput, err error) { GetItemWithContextFn: func(ctx context.Context, input *dynamodb.GetItemInput, op ...request.Option) (output *dynamodb.GetItemOutput, err error) {
return &dynamodb.GetItemOutput{ return &dynamodb.GetItemOutput{
Item: map[string]*dynamodb.AttributeValue{ Item: map[string]*dynamodb.AttributeValue{
"key": { "key": {
@ -195,7 +197,7 @@ func TestGet(t *testing.T) {
Consistency: "strong", Consistency: "strong",
}, },
} }
out, err := ss.Get(req) out, err := ss.Get(context.TODO(), req)
assert.Nil(t, err) assert.Nil(t, err)
assert.Nil(t, out.Data) assert.Nil(t, out.Data)
assert.Nil(t, out.ETag) assert.Nil(t, out.ETag)
@ -203,7 +205,7 @@ func TestGet(t *testing.T) {
t.Run("Unsuccessfully get item", func(t *testing.T) { t.Run("Unsuccessfully get item", func(t *testing.T) {
ss := StateStore{ ss := StateStore{
client: &mockedDynamoDB{ client: &mockedDynamoDB{
GetItemFn: func(input *dynamodb.GetItemInput) (output *dynamodb.GetItemOutput, err error) { GetItemWithContextFn: func(ctx context.Context, input *dynamodb.GetItemInput, op ...request.Option) (output *dynamodb.GetItemOutput, err error) {
return nil, fmt.Errorf("failed to retrieve data") return nil, fmt.Errorf("failed to retrieve data")
}, },
}, },
@ -215,14 +217,14 @@ func TestGet(t *testing.T) {
Consistency: "strong", Consistency: "strong",
}, },
} }
out, err := ss.Get(req) out, err := ss.Get(context.TODO(), req)
assert.NotNil(t, err) assert.NotNil(t, err)
assert.Nil(t, out) assert.Nil(t, out)
}) })
t.Run("Unsuccessfully with empty response", func(t *testing.T) { t.Run("Unsuccessfully with empty response", func(t *testing.T) {
ss := StateStore{ ss := StateStore{
client: &mockedDynamoDB{ client: &mockedDynamoDB{
GetItemFn: func(input *dynamodb.GetItemInput) (output *dynamodb.GetItemOutput, err error) { GetItemWithContextFn: func(ctx context.Context, input *dynamodb.GetItemInput, op ...request.Option) (output *dynamodb.GetItemOutput, err error) {
return &dynamodb.GetItemOutput{ return &dynamodb.GetItemOutput{
Item: map[string]*dynamodb.AttributeValue{}, Item: map[string]*dynamodb.AttributeValue{},
}, nil }, nil
@ -236,7 +238,7 @@ func TestGet(t *testing.T) {
Consistency: "strong", Consistency: "strong",
}, },
} }
out, err := ss.Get(req) out, err := ss.Get(context.TODO(), req)
assert.Nil(t, err) assert.Nil(t, err)
assert.Nil(t, out.Data) assert.Nil(t, out.Data)
assert.Nil(t, out.ETag) assert.Nil(t, out.ETag)
@ -244,7 +246,7 @@ func TestGet(t *testing.T) {
t.Run("Unsuccessfully with no required key", func(t *testing.T) { t.Run("Unsuccessfully with no required key", func(t *testing.T) {
ss := StateStore{ ss := StateStore{
client: &mockedDynamoDB{ client: &mockedDynamoDB{
GetItemFn: func(input *dynamodb.GetItemInput) (output *dynamodb.GetItemOutput, err error) { GetItemWithContextFn: func(ctx context.Context, input *dynamodb.GetItemInput, op ...request.Option) (output *dynamodb.GetItemOutput, err error) {
return &dynamodb.GetItemOutput{ return &dynamodb.GetItemOutput{
Item: map[string]*dynamodb.AttributeValue{ Item: map[string]*dynamodb.AttributeValue{
"value2": { "value2": {
@ -262,7 +264,7 @@ func TestGet(t *testing.T) {
Consistency: "strong", Consistency: "strong",
}, },
} }
out, err := ss.Get(req) out, err := ss.Get(context.TODO(), req)
assert.Nil(t, err) assert.Nil(t, err)
assert.Empty(t, out.Data) assert.Empty(t, out.Data)
}) })
@ -276,7 +278,7 @@ func TestSet(t *testing.T) {
t.Run("Successfully set item", func(t *testing.T) { t.Run("Successfully set item", func(t *testing.T) {
ss := StateStore{ ss := StateStore{
client: &mockedDynamoDB{ client: &mockedDynamoDB{
PutItemFn: func(input *dynamodb.PutItemInput) (output *dynamodb.PutItemOutput, err error) { PutItemWithContextFn: func(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (output *dynamodb.PutItemOutput, err error) {
assert.Equal(t, dynamodb.AttributeValue{ assert.Equal(t, dynamodb.AttributeValue{
S: aws.String("key"), S: aws.String("key"),
}, *input.Item["key"]) }, *input.Item["key"])
@ -301,14 +303,14 @@ func TestSet(t *testing.T) {
Value: "value", Value: "value",
}, },
} }
err := ss.Set(req) err := ss.Set(context.TODO(), req)
assert.Nil(t, err) assert.Nil(t, err)
}) })
t.Run("Successfully set item with matching etag", func(t *testing.T) { t.Run("Successfully set item with matching etag", func(t *testing.T) {
ss := StateStore{ ss := StateStore{
client: &mockedDynamoDB{ client: &mockedDynamoDB{
PutItemFn: func(input *dynamodb.PutItemInput) (output *dynamodb.PutItemOutput, err error) { PutItemWithContextFn: func(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (output *dynamodb.PutItemOutput, err error) {
assert.Equal(t, dynamodb.AttributeValue{ assert.Equal(t, dynamodb.AttributeValue{
S: aws.String("key"), S: aws.String("key"),
}, *input.Item["key"]) }, *input.Item["key"])
@ -339,14 +341,14 @@ func TestSet(t *testing.T) {
Value: "value", Value: "value",
}, },
} }
err := ss.Set(req) err := ss.Set(context.TODO(), req)
assert.Nil(t, err) assert.Nil(t, err)
}) })
t.Run("Unsuccessfully set item with mismatched etag", func(t *testing.T) { t.Run("Unsuccessfully set item with mismatched etag", func(t *testing.T) {
ss := StateStore{ ss := StateStore{
client: &mockedDynamoDB{ client: &mockedDynamoDB{
PutItemFn: func(input *dynamodb.PutItemInput) (output *dynamodb.PutItemOutput, err error) { PutItemWithContextFn: func(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (output *dynamodb.PutItemOutput, err error) {
assert.Equal(t, dynamodb.AttributeValue{ assert.Equal(t, dynamodb.AttributeValue{
S: aws.String("key"), S: aws.String("key"),
}, *input.Item["key"]) }, *input.Item["key"])
@ -373,7 +375,7 @@ func TestSet(t *testing.T) {
}, },
} }
err := ss.Set(req) err := ss.Set(context.TODO(), req)
assert.NotNil(t, err) assert.NotNil(t, err)
switch tagErr := err.(type) { switch tagErr := err.(type) {
case *state.ETagError: case *state.ETagError:
@ -386,7 +388,7 @@ func TestSet(t *testing.T) {
t.Run("Successfully set item with first-write-concurrency", func(t *testing.T) { t.Run("Successfully set item with first-write-concurrency", func(t *testing.T) {
ss := StateStore{ ss := StateStore{
client: &mockedDynamoDB{ client: &mockedDynamoDB{
PutItemFn: func(input *dynamodb.PutItemInput) (output *dynamodb.PutItemOutput, err error) { PutItemWithContextFn: func(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (output *dynamodb.PutItemOutput, err error) {
assert.Equal(t, dynamodb.AttributeValue{ assert.Equal(t, dynamodb.AttributeValue{
S: aws.String("key"), S: aws.String("key"),
}, *input.Item["key"]) }, *input.Item["key"])
@ -415,14 +417,14 @@ func TestSet(t *testing.T) {
Concurrency: state.FirstWrite, Concurrency: state.FirstWrite,
}, },
} }
err := ss.Set(req) err := ss.Set(context.TODO(), req)
assert.Nil(t, err) assert.Nil(t, err)
}) })
t.Run("Unsuccessfully set item with first-write-concurrency", func(t *testing.T) { t.Run("Unsuccessfully set item with first-write-concurrency", func(t *testing.T) {
ss := StateStore{ ss := StateStore{
client: &mockedDynamoDB{ client: &mockedDynamoDB{
PutItemFn: func(input *dynamodb.PutItemInput) (output *dynamodb.PutItemOutput, err error) { PutItemWithContextFn: func(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (output *dynamodb.PutItemOutput, err error) {
assert.Equal(t, dynamodb.AttributeValue{ assert.Equal(t, dynamodb.AttributeValue{
S: aws.String("key"), S: aws.String("key"),
}, *input.Item["key"]) }, *input.Item["key"])
@ -446,7 +448,7 @@ func TestSet(t *testing.T) {
Concurrency: state.FirstWrite, Concurrency: state.FirstWrite,
}, },
} }
err := ss.Set(req) err := ss.Set(context.TODO(), req)
assert.NotNil(t, err) assert.NotNil(t, err)
switch err.(type) { switch err.(type) {
case *state.ETagError: case *state.ETagError:
@ -458,7 +460,7 @@ func TestSet(t *testing.T) {
t.Run("Successfully set item with ttl = -1", func(t *testing.T) { t.Run("Successfully set item with ttl = -1", func(t *testing.T) {
ss := StateStore{ ss := StateStore{
client: &mockedDynamoDB{ client: &mockedDynamoDB{
PutItemFn: func(input *dynamodb.PutItemInput) (output *dynamodb.PutItemOutput, err error) { PutItemWithContextFn: func(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (output *dynamodb.PutItemOutput, err error) {
assert.Equal(t, len(input.Item), 4) assert.Equal(t, len(input.Item), 4)
result := DynamoDBItem{} result := DynamoDBItem{}
dynamodbattribute.UnmarshalMap(input.Item, &result) dynamodbattribute.UnmarshalMap(input.Item, &result)
@ -487,13 +489,13 @@ func TestSet(t *testing.T) {
"ttlInSeconds": "-1", "ttlInSeconds": "-1",
}, },
} }
err := ss.Set(req) err := ss.Set(context.TODO(), req)
assert.Nil(t, err) assert.Nil(t, err)
}) })
t.Run("Successfully set item with 'correct' ttl", func(t *testing.T) { t.Run("Successfully set item with 'correct' ttl", func(t *testing.T) {
ss := StateStore{ ss := StateStore{
client: &mockedDynamoDB{ client: &mockedDynamoDB{
PutItemFn: func(input *dynamodb.PutItemInput) (output *dynamodb.PutItemOutput, err error) { PutItemWithContextFn: func(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (output *dynamodb.PutItemOutput, err error) {
assert.Equal(t, len(input.Item), 4) assert.Equal(t, len(input.Item), 4)
result := DynamoDBItem{} result := DynamoDBItem{}
dynamodbattribute.UnmarshalMap(input.Item, &result) dynamodbattribute.UnmarshalMap(input.Item, &result)
@ -522,14 +524,14 @@ func TestSet(t *testing.T) {
"ttlInSeconds": "180", "ttlInSeconds": "180",
}, },
} }
err := ss.Set(req) err := ss.Set(context.TODO(), req)
assert.Nil(t, err) assert.Nil(t, err)
}) })
t.Run("Unsuccessfully set item", func(t *testing.T) { t.Run("Unsuccessfully set item", func(t *testing.T) {
ss := StateStore{ ss := StateStore{
client: &mockedDynamoDB{ client: &mockedDynamoDB{
PutItemFn: func(input *dynamodb.PutItemInput) (output *dynamodb.PutItemOutput, err error) { PutItemWithContextFn: func(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (output *dynamodb.PutItemOutput, err error) {
return nil, fmt.Errorf("unable to put item") return nil, fmt.Errorf("unable to put item")
}, },
}, },
@ -540,13 +542,13 @@ func TestSet(t *testing.T) {
Value: "value", Value: "value",
}, },
} }
err := ss.Set(req) err := ss.Set(context.TODO(), req)
assert.NotNil(t, err) assert.NotNil(t, err)
}) })
t.Run("Successfully set item with correct ttl but without component metadata", func(t *testing.T) { t.Run("Successfully set item with correct ttl but without component metadata", func(t *testing.T) {
ss := StateStore{ ss := StateStore{
client: &mockedDynamoDB{ client: &mockedDynamoDB{
PutItemFn: func(input *dynamodb.PutItemInput) (output *dynamodb.PutItemOutput, err error) { PutItemWithContextFn: func(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (output *dynamodb.PutItemOutput, err error) {
assert.Equal(t, dynamodb.AttributeValue{ assert.Equal(t, dynamodb.AttributeValue{
S: aws.String("someKey"), S: aws.String("someKey"),
}, *input.Item["key"]) }, *input.Item["key"])
@ -574,13 +576,13 @@ func TestSet(t *testing.T) {
"ttlInSeconds": "180", "ttlInSeconds": "180",
}, },
} }
err := ss.Set(req) err := ss.Set(context.TODO(), req)
assert.Nil(t, err) assert.Nil(t, err)
}) })
t.Run("Unsuccessfully set item with ttl (invalid value)", func(t *testing.T) { t.Run("Unsuccessfully set item with ttl (invalid value)", func(t *testing.T) {
ss := StateStore{ ss := StateStore{
client: &mockedDynamoDB{ client: &mockedDynamoDB{
PutItemFn: func(input *dynamodb.PutItemInput) (output *dynamodb.PutItemOutput, err error) { PutItemWithContextFn: func(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (output *dynamodb.PutItemOutput, err error) {
assert.Equal(t, map[string]*dynamodb.AttributeValue{ assert.Equal(t, map[string]*dynamodb.AttributeValue{
"key": { "key": {
S: aws.String("somekey"), S: aws.String("somekey"),
@ -613,7 +615,7 @@ func TestSet(t *testing.T) {
"ttlInSeconds": "invalidvalue", "ttlInSeconds": "invalidvalue",
}, },
} }
err := ss.Set(req) err := ss.Set(context.TODO(), req)
assert.NotNil(t, err) assert.NotNil(t, err)
assert.Equal(t, "dynamodb error: failed to parse ttlInSeconds: strconv.ParseInt: parsing \"invalidvalue\": invalid syntax", err.Error()) assert.Equal(t, "dynamodb error: failed to parse ttlInSeconds: strconv.ParseInt: parsing \"invalidvalue\": invalid syntax", err.Error())
}) })
@ -628,7 +630,7 @@ func TestBulkSet(t *testing.T) {
tableName := "table_name" tableName := "table_name"
ss := StateStore{ ss := StateStore{
client: &mockedDynamoDB{ client: &mockedDynamoDB{
BatchWriteItemFn: func(input *dynamodb.BatchWriteItemInput) (output *dynamodb.BatchWriteItemOutput, err error) { BatchWriteItemWithContextFn: func(ctx context.Context, input *dynamodb.BatchWriteItemInput, op ...request.Option) (output *dynamodb.BatchWriteItemOutput, err error) {
expected := map[string][]*dynamodb.WriteRequest{} expected := map[string][]*dynamodb.WriteRequest{}
expected[tableName] = []*dynamodb.WriteRequest{ expected[tableName] = []*dynamodb.WriteRequest{
{ {
@ -688,14 +690,14 @@ func TestBulkSet(t *testing.T) {
}, },
}, },
} }
err := ss.BulkSet(req) err := ss.BulkSet(context.TODO(), req)
assert.Nil(t, err) assert.Nil(t, err)
}) })
t.Run("Successfully set items with ttl = -1", func(t *testing.T) { t.Run("Successfully set items with ttl = -1", func(t *testing.T) {
tableName := "table_name" tableName := "table_name"
ss := StateStore{ ss := StateStore{
client: &mockedDynamoDB{ client: &mockedDynamoDB{
BatchWriteItemFn: func(input *dynamodb.BatchWriteItemInput) (output *dynamodb.BatchWriteItemOutput, err error) { BatchWriteItemWithContextFn: func(ctx context.Context, input *dynamodb.BatchWriteItemInput, op ...request.Option) (output *dynamodb.BatchWriteItemOutput, err error) {
expected := map[string][]*dynamodb.WriteRequest{} expected := map[string][]*dynamodb.WriteRequest{}
expected[tableName] = []*dynamodb.WriteRequest{ expected[tableName] = []*dynamodb.WriteRequest{
{ {
@ -761,14 +763,14 @@ func TestBulkSet(t *testing.T) {
}, },
}, },
} }
err := ss.BulkSet(req) err := ss.BulkSet(context.TODO(), req)
assert.Nil(t, err) assert.Nil(t, err)
}) })
t.Run("Successfully set items with ttl", func(t *testing.T) { t.Run("Successfully set items with ttl", func(t *testing.T) {
tableName := "table_name" tableName := "table_name"
ss := StateStore{ ss := StateStore{
client: &mockedDynamoDB{ client: &mockedDynamoDB{
BatchWriteItemFn: func(input *dynamodb.BatchWriteItemInput) (output *dynamodb.BatchWriteItemOutput, err error) { BatchWriteItemWithContextFn: func(ctx context.Context, input *dynamodb.BatchWriteItemInput, op ...request.Option) (output *dynamodb.BatchWriteItemOutput, err error) {
expected := map[string][]*dynamodb.WriteRequest{} expected := map[string][]*dynamodb.WriteRequest{}
// This might fail occasionally due to timestamp precision. // This might fail occasionally due to timestamp precision.
timestamp := time.Now().Unix() + 90 timestamp := time.Now().Unix() + 90
@ -836,13 +838,13 @@ func TestBulkSet(t *testing.T) {
}, },
}, },
} }
err := ss.BulkSet(req) err := ss.BulkSet(context.TODO(), req)
assert.Nil(t, err) assert.Nil(t, err)
}) })
t.Run("Unsuccessfully set items", func(t *testing.T) { t.Run("Unsuccessfully set items", func(t *testing.T) {
ss := StateStore{ ss := StateStore{
client: &mockedDynamoDB{ client: &mockedDynamoDB{
BatchWriteItemFn: func(input *dynamodb.BatchWriteItemInput) (output *dynamodb.BatchWriteItemOutput, err error) { BatchWriteItemWithContextFn: func(ctx context.Context, input *dynamodb.BatchWriteItemInput, op ...request.Option) (output *dynamodb.BatchWriteItemOutput, err error) {
return nil, fmt.Errorf("unable to bulk write items") return nil, fmt.Errorf("unable to bulk write items")
}, },
}, },
@ -861,7 +863,7 @@ func TestBulkSet(t *testing.T) {
}, },
}, },
} }
err := ss.BulkSet(req) err := ss.BulkSet(context.TODO(), req)
assert.NotNil(t, err) assert.NotNil(t, err)
}) })
} }
@ -874,7 +876,7 @@ func TestDelete(t *testing.T) {
ss := StateStore{ ss := StateStore{
client: &mockedDynamoDB{ client: &mockedDynamoDB{
DeleteItemFn: func(input *dynamodb.DeleteItemInput) (output *dynamodb.DeleteItemOutput, err error) { DeleteItemWithContextFn: func(ctx context.Context, input *dynamodb.DeleteItemInput, op ...request.Option) (output *dynamodb.DeleteItemOutput, err error) {
assert.Equal(t, map[string]*dynamodb.AttributeValue{ assert.Equal(t, map[string]*dynamodb.AttributeValue{
"key": { "key": {
S: aws.String(req.Key), S: aws.String(req.Key),
@ -885,7 +887,7 @@ func TestDelete(t *testing.T) {
}, },
}, },
} }
err := ss.Delete(req) err := ss.Delete(context.TODO(), req)
assert.Nil(t, err) assert.Nil(t, err)
}) })
@ -898,7 +900,7 @@ func TestDelete(t *testing.T) {
ss := StateStore{ ss := StateStore{
client: &mockedDynamoDB{ client: &mockedDynamoDB{
DeleteItemFn: func(input *dynamodb.DeleteItemInput) (output *dynamodb.DeleteItemOutput, err error) { DeleteItemWithContextFn: func(ctx context.Context, input *dynamodb.DeleteItemInput, op ...request.Option) (output *dynamodb.DeleteItemOutput, err error) {
assert.Equal(t, map[string]*dynamodb.AttributeValue{ assert.Equal(t, map[string]*dynamodb.AttributeValue{
"key": { "key": {
S: aws.String(req.Key), S: aws.String(req.Key),
@ -913,7 +915,7 @@ func TestDelete(t *testing.T) {
}, },
}, },
} }
err := ss.Delete(req) err := ss.Delete(context.TODO(), req)
assert.Nil(t, err) assert.Nil(t, err)
}) })
@ -926,7 +928,7 @@ func TestDelete(t *testing.T) {
ss := StateStore{ ss := StateStore{
client: &mockedDynamoDB{ client: &mockedDynamoDB{
DeleteItemFn: func(input *dynamodb.DeleteItemInput) (output *dynamodb.DeleteItemOutput, err error) { DeleteItemWithContextFn: func(ctx context.Context, input *dynamodb.DeleteItemInput, op ...request.Option) (output *dynamodb.DeleteItemOutput, err error) {
assert.Equal(t, map[string]*dynamodb.AttributeValue{ assert.Equal(t, map[string]*dynamodb.AttributeValue{
"key": { "key": {
S: aws.String(req.Key), S: aws.String(req.Key),
@ -942,7 +944,7 @@ func TestDelete(t *testing.T) {
}, },
}, },
} }
err := ss.Delete(req) err := ss.Delete(context.TODO(), req)
assert.NotNil(t, err) assert.NotNil(t, err)
switch tagErr := err.(type) { switch tagErr := err.(type) {
case *state.ETagError: case *state.ETagError:
@ -955,7 +957,7 @@ func TestDelete(t *testing.T) {
t.Run("Unsuccessfully delete item", func(t *testing.T) { t.Run("Unsuccessfully delete item", func(t *testing.T) {
ss := StateStore{ ss := StateStore{
client: &mockedDynamoDB{ client: &mockedDynamoDB{
DeleteItemFn: func(input *dynamodb.DeleteItemInput) (output *dynamodb.DeleteItemOutput, err error) { DeleteItemWithContextFn: func(ctx context.Context, input *dynamodb.DeleteItemInput, op ...request.Option) (output *dynamodb.DeleteItemOutput, err error) {
return nil, fmt.Errorf("unable to delete item") return nil, fmt.Errorf("unable to delete item")
}, },
}, },
@ -963,7 +965,7 @@ func TestDelete(t *testing.T) {
req := &state.DeleteRequest{ req := &state.DeleteRequest{
Key: "key", Key: "key",
} }
err := ss.Delete(req) err := ss.Delete(context.TODO(), req)
assert.NotNil(t, err) assert.NotNil(t, err)
}) })
} }
@ -973,7 +975,7 @@ func TestBulkDelete(t *testing.T) {
tableName := "table_name" tableName := "table_name"
ss := StateStore{ ss := StateStore{
client: &mockedDynamoDB{ client: &mockedDynamoDB{
BatchWriteItemFn: func(input *dynamodb.BatchWriteItemInput) (output *dynamodb.BatchWriteItemOutput, err error) { BatchWriteItemWithContextFn: func(ctx context.Context, input *dynamodb.BatchWriteItemInput, op ...request.Option) (output *dynamodb.BatchWriteItemOutput, err error) {
expected := map[string][]*dynamodb.WriteRequest{} expected := map[string][]*dynamodb.WriteRequest{}
expected[tableName] = []*dynamodb.WriteRequest{ expected[tableName] = []*dynamodb.WriteRequest{
{ {
@ -1012,13 +1014,13 @@ func TestBulkDelete(t *testing.T) {
Key: "key2", Key: "key2",
}, },
} }
err := ss.BulkDelete(req) err := ss.BulkDelete(context.TODO(), req)
assert.Nil(t, err) assert.Nil(t, err)
}) })
t.Run("Unsuccessfully delete items", func(t *testing.T) { t.Run("Unsuccessfully delete items", func(t *testing.T) {
ss := StateStore{ ss := StateStore{
client: &mockedDynamoDB{ client: &mockedDynamoDB{
BatchWriteItemFn: func(input *dynamodb.BatchWriteItemInput) (output *dynamodb.BatchWriteItemOutput, err error) { BatchWriteItemWithContextFn: func(ctx context.Context, input *dynamodb.BatchWriteItemInput, op ...request.Option) (output *dynamodb.BatchWriteItemOutput, err error) {
return nil, fmt.Errorf("unable to bulk write items") return nil, fmt.Errorf("unable to bulk write items")
}, },
}, },
@ -1031,7 +1033,7 @@ func TestBulkDelete(t *testing.T) {
Key: "key", Key: "key",
}, },
} }
err := ss.BulkDelete(req) err := ss.BulkDelete(context.TODO(), req)
assert.NotNil(t, err) assert.NotNil(t, err)
}) })
} }

View File

@ -130,21 +130,21 @@ func (r *StateStore) Features() []state.Feature {
} }
// Delete the state. // Delete the state.
func (r *StateStore) Delete(req *state.DeleteRequest) error { func (r *StateStore) Delete(ctx context.Context, req *state.DeleteRequest) error {
r.logger.Debugf("delete %s", req.Key) r.logger.Debugf("delete %s", req.Key)
return r.deleteFile(context.Background(), req) return r.deleteFile(ctx, req)
} }
// Get the state. // Get the state.
func (r *StateStore) Get(req *state.GetRequest) (*state.GetResponse, error) { func (r *StateStore) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) {
r.logger.Debugf("get %s", req.Key) r.logger.Debugf("get %s", req.Key)
return r.readFile(context.Background(), req) return r.readFile(ctx, req)
} }
// Set the state. // Set the state.
func (r *StateStore) Set(req *state.SetRequest) error { func (r *StateStore) Set(ctx context.Context, req *state.SetRequest) error {
r.logger.Debugf("saving %s", req.Key) r.logger.Debugf("saving %s", req.Key)
return r.writeFile(context.Background(), req) return r.writeFile(ctx, req)
} }
func (r *StateStore) Ping() error { func (r *StateStore) Ping() error {

View File

@ -202,7 +202,7 @@ func (c *StateStore) Features() []state.Feature {
} }
// Get retrieves a CosmosDB item. // Get retrieves a CosmosDB item.
func (c *StateStore) Get(req *state.GetRequest) (*state.GetResponse, error) { func (c *StateStore) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) {
partitionKey := populatePartitionMetadata(req.Key, req.Metadata) partitionKey := populatePartitionMetadata(req.Key, req.Metadata)
options := azcosmos.ItemOptions{} options := azcosmos.ItemOptions{}
@ -212,9 +212,7 @@ func (c *StateStore) Get(req *state.GetRequest) (*state.GetResponse, error) {
options.ConsistencyLevel = azcosmos.ConsistencyLevelEventual.ToPtr() options.ConsistencyLevel = azcosmos.ConsistencyLevelEventual.ToPtr()
} }
ctx, cancel := context.WithTimeout(context.Background(), defaultTimeout)
readItem, err := c.client.ReadItem(ctx, azcosmos.NewPartitionKeyString(partitionKey), req.Key, &options) readItem, err := c.client.ReadItem(ctx, azcosmos.NewPartitionKeyString(partitionKey), req.Key, &options)
cancel()
if err != nil { if err != nil {
var responseErr *azcore.ResponseError var responseErr *azcore.ResponseError
if errors.As(err, &responseErr) && responseErr.ErrorCode == "NotFound" { if errors.As(err, &responseErr) && responseErr.ErrorCode == "NotFound" {
@ -263,7 +261,7 @@ func (c *StateStore) Get(req *state.GetRequest) (*state.GetResponse, error) {
} }
// Set saves a CosmosDB item. // Set saves a CosmosDB item.
func (c *StateStore) Set(req *state.SetRequest) error { func (c *StateStore) Set(ctx context.Context, req *state.SetRequest) error {
err := state.CheckRequestOptions(req.Options) err := state.CheckRequestOptions(req.Options)
if err != nil { if err != nil {
return err return err
@ -301,10 +299,8 @@ func (c *StateStore) Set(req *state.SetRequest) error {
return err return err
} }
ctx, cancel := context.WithTimeout(context.Background(), defaultTimeout)
pk := azcosmos.NewPartitionKeyString(partitionKey) pk := azcosmos.NewPartitionKeyString(partitionKey)
_, err = c.client.UpsertItem(ctx, pk, marsh, &options) _, err = c.client.UpsertItem(ctx, pk, marsh, &options)
cancel()
if err != nil { if err != nil {
return err return err
} }
@ -312,7 +308,7 @@ func (c *StateStore) Set(req *state.SetRequest) error {
} }
// Delete performs a delete operation. // Delete performs a delete operation.
func (c *StateStore) Delete(req *state.DeleteRequest) error { func (c *StateStore) Delete(ctx context.Context, req *state.DeleteRequest) error {
err := state.CheckRequestOptions(req.Options) err := state.CheckRequestOptions(req.Options)
if err != nil { if err != nil {
return err return err
@ -330,10 +326,8 @@ func (c *StateStore) Delete(req *state.DeleteRequest) error {
options.ConsistencyLevel = azcosmos.ConsistencyLevelEventual.ToPtr() options.ConsistencyLevel = azcosmos.ConsistencyLevelEventual.ToPtr()
} }
ctx, cancel := context.WithTimeout(context.Background(), defaultTimeout)
pk := azcosmos.NewPartitionKeyString(partitionKey) pk := azcosmos.NewPartitionKeyString(partitionKey)
_, err = c.client.DeleteItem(ctx, pk, req.Key, &options) _, err = c.client.DeleteItem(ctx, pk, req.Key, &options)
cancel()
if err != nil && !isNotFoundError(err) { if err != nil && !isNotFoundError(err) {
c.logger.Debugf("Error from cosmos.DeleteDocument e=%e, e.Error=%s", err, err.Error()) c.logger.Debugf("Error from cosmos.DeleteDocument e=%e, e.Error=%s", err, err.Error())
if req.ETag != nil && *req.ETag != "" { if req.ETag != nil && *req.ETag != "" {
@ -346,7 +340,7 @@ func (c *StateStore) Delete(req *state.DeleteRequest) error {
} }
// Multi performs a transactional operation. succeeds only if all operations succeed, and fails if one or more operations fail. // Multi performs a transactional operation. succeeds only if all operations succeed, and fails if one or more operations fail.
func (c *StateStore) Multi(request *state.TransactionalStateRequest) (err error) { func (c *StateStore) Multi(ctx context.Context, request *state.TransactionalStateRequest) (err error) {
if len(request.Operations) == 0 { if len(request.Operations) == 0 {
c.logger.Debugf("No Operations Provided") c.logger.Debugf("No Operations Provided")
return nil return nil
@ -413,9 +407,7 @@ func (c *StateStore) Multi(request *state.TransactionalStateRequest) (err error)
c.logger.Debugf("#operations=%d,partitionkey=%s", numOperations, partitionKey) c.logger.Debugf("#operations=%d,partitionkey=%s", numOperations, partitionKey)
ctx, cancel := context.WithTimeout(context.Background(), defaultTimeout)
batchResponse, err := c.client.ExecuteTransactionalBatch(ctx, batch, nil) batchResponse, err := c.client.ExecuteTransactionalBatch(ctx, batch, nil)
cancel()
if err != nil { if err != nil {
return err return err
} }
@ -440,7 +432,7 @@ func (c *StateStore) Multi(request *state.TransactionalStateRequest) (err error)
return nil return nil
} }
func (c *StateStore) Query(req *state.QueryRequest) (*state.QueryResponse, error) { func (c *StateStore) Query(ctx context.Context, req *state.QueryRequest) (*state.QueryResponse, error) {
q := &Query{} q := &Query{}
qbuilder := query.NewQueryBuilder(q) qbuilder := query.NewQueryBuilder(q)
@ -448,7 +440,7 @@ func (c *StateStore) Query(req *state.QueryRequest) (*state.QueryResponse, error
return &state.QueryResponse{}, err return &state.QueryResponse{}, err
} }
data, token, err := q.execute(c.client) data, token, err := q.execute(ctx, c.client)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -144,7 +144,7 @@ func (q *Query) setNextParameter(val string) string {
return pname return pname
} }
func (q *Query) execute(client *azcosmos.ContainerClient) ([]state.QueryItem, string, error) { func (q *Query) execute(ctx context.Context, client *azcosmos.ContainerClient) ([]state.QueryItem, string, error) {
opts := &azcosmos.QueryOptions{} opts := &azcosmos.QueryOptions{}
resultLimit := q.limit resultLimit := q.limit
@ -160,9 +160,7 @@ func (q *Query) execute(client *azcosmos.ContainerClient) ([]state.QueryItem, st
token := "" token := ""
for queryPager.More() { for queryPager.More() {
ctx, cancel := context.WithTimeout(context.Background(), defaultTimeout)
queryResponse, innerErr := queryPager.NextPage(ctx) queryResponse, innerErr := queryPager.NextPage(ctx)
cancel()
if innerErr != nil { if innerErr != nil {
return nil, "", innerErr return nil, "", innerErr
} }

View File

@ -163,10 +163,10 @@ func (r *StateStore) Features() []state.Feature {
return r.features return r.features
} }
func (r *StateStore) Delete(req *state.DeleteRequest) error { func (r *StateStore) Delete(ctx context.Context, req *state.DeleteRequest) error {
r.logger.Debugf("delete %s", req.Key) r.logger.Debugf("delete %s", req.Key)
err := r.deleteRow(req) err := r.deleteRow(ctx, req)
if err != nil { if err != nil {
if req.ETag != nil { if req.ETag != nil {
return state.NewETagError(state.ETagMismatch, err) return state.NewETagError(state.ETagMismatch, err)
@ -179,12 +179,10 @@ func (r *StateStore) Delete(req *state.DeleteRequest) error {
return err return err
} }
func (r *StateStore) Get(req *state.GetRequest) (*state.GetResponse, error) { func (r *StateStore) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) {
r.logger.Debugf("fetching %s", req.Key) r.logger.Debugf("fetching %s", req.Key)
pk, rk := getPartitionAndRowKey(req.Key, r.cosmosDBMode) pk, rk := getPartitionAndRowKey(req.Key, r.cosmosDBMode)
getContext, cancel := context.WithTimeout(context.Background(), timeout) resp, err := r.client.GetEntity(ctx, pk, rk, nil)
defer cancel()
resp, err := r.client.GetEntity(getContext, pk, rk, nil)
if err != nil { if err != nil {
if isNotFoundError(err) { if isNotFoundError(err) {
return &state.GetResponse{}, nil return &state.GetResponse{}, nil
@ -200,10 +198,10 @@ func (r *StateStore) Get(req *state.GetRequest) (*state.GetResponse, error) {
}, err }, err
} }
func (r *StateStore) Set(req *state.SetRequest) error { func (r *StateStore) Set(ctx context.Context, req *state.SetRequest) error {
r.logger.Debugf("saving %s", req.Key) r.logger.Debugf("saving %s", req.Key)
err := r.writeRow(req) err := r.writeRow(ctx, req)
return err return err
} }
@ -254,30 +252,26 @@ func getTablesMetadata(metadata map[string]string) (*tablesMetadata, error) {
return &meta, nil return &meta, nil
} }
func (r *StateStore) writeRow(req *state.SetRequest) error { func (r *StateStore) writeRow(ctx context.Context, req *state.SetRequest) error {
marshalledEntity, err := r.marshal(req) marshalledEntity, err := r.marshal(req)
if err != nil { if err != nil {
return err return err
} }
writeContext, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
// InsertOrReplace does not support ETag concurrency, therefore we will use Insert to check for key existence // InsertOrReplace does not support ETag concurrency, therefore we will use Insert to check for key existence
// and then use Update to update the key if it exists with the specified ETag // and then use Update to update the key if it exists with the specified ETag
_, err = r.client.AddEntity(writeContext, marshalledEntity, nil) _, err = r.client.AddEntity(ctx, marshalledEntity, nil)
if err != nil { if err != nil {
// If Insert failed because item already exists, try to Update instead per Upsert semantics // If Insert failed because item already exists, try to Update instead per Upsert semantics
if isEntityAlreadyExistsError(err) { if isEntityAlreadyExistsError(err) {
updateContext, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
// Always Update using the etag when provided even if Concurrency != FirstWrite. // Always Update using the etag when provided even if Concurrency != FirstWrite.
// Today the presence of etag takes precedence over Concurrency. // Today the presence of etag takes precedence over Concurrency.
// In the future #2739 will impose a breaking change which must disallow the use of etag when not using FirstWrite. // In the future #2739 will impose a breaking change which must disallow the use of etag when not using FirstWrite.
if req.ETag != nil && *req.ETag != "" { if req.ETag != nil && *req.ETag != "" {
etag := azcore.ETag(*req.ETag) etag := azcore.ETag(*req.ETag)
_, uerr := r.client.UpdateEntity(updateContext, marshalledEntity, &aztables.UpdateEntityOptions{ _, uerr := r.client.UpdateEntity(ctx, marshalledEntity, &aztables.UpdateEntityOptions{
IfMatch: &etag, IfMatch: &etag,
UpdateMode: aztables.UpdateModeReplace, UpdateMode: aztables.UpdateModeReplace,
}) })
@ -295,7 +289,7 @@ func (r *StateStore) writeRow(req *state.SetRequest) error {
return state.NewETagError(state.ETagMismatch, errors.New("update with Concurrency.FirstWrite without ETag")) return state.NewETagError(state.ETagMismatch, errors.New("update with Concurrency.FirstWrite without ETag"))
} else { } else {
// Finally, last write semantics without ETag should always perform a force update. // Finally, last write semantics without ETag should always perform a force update.
_, uerr := r.client.UpdateEntity(updateContext, marshalledEntity, &aztables.UpdateEntityOptions{ _, uerr := r.client.UpdateEntity(ctx, marshalledEntity, &aztables.UpdateEntityOptions{
IfMatch: nil, // this is the same as "*" matching all ETags IfMatch: nil, // this is the same as "*" matching all ETags
UpdateMode: aztables.UpdateModeReplace, UpdateMode: aztables.UpdateModeReplace,
}) })
@ -336,19 +330,16 @@ func isTableAlreadyExistsError(err error) bool {
return false return false
} }
func (r *StateStore) deleteRow(req *state.DeleteRequest) error { func (r *StateStore) deleteRow(ctx context.Context, req *state.DeleteRequest) error {
pk, rk := getPartitionAndRowKey(req.Key, r.cosmosDBMode) pk, rk := getPartitionAndRowKey(req.Key, r.cosmosDBMode)
deleteContext, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
if req.ETag != nil { if req.ETag != nil {
azcoreETag := azcore.ETag(*req.ETag) azcoreETag := azcore.ETag(*req.ETag)
_, err := r.client.DeleteEntity(deleteContext, pk, rk, &aztables.DeleteEntityOptions{IfMatch: &azcoreETag}) _, err := r.client.DeleteEntity(ctx, pk, rk, &aztables.DeleteEntityOptions{IfMatch: &azcoreETag})
return err return err
} }
all := azcore.ETagAny all := azcore.ETagAny
_, err := r.client.DeleteEntity(deleteContext, pk, rk, &aztables.DeleteEntityOptions{IfMatch: &all}) _, err := r.client.DeleteEntity(ctx, pk, rk, &aztables.DeleteEntityOptions{IfMatch: &all})
return err return err
} }

View File

@ -14,6 +14,7 @@ limitations under the License.
package cassandra package cassandra
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"strconv" "strconv"
@ -230,12 +231,12 @@ func getCassandraMetadata(metadata state.Metadata) (*cassandraMetadata, error) {
} }
// Delete performs a delete operation. // Delete performs a delete operation.
func (c *Cassandra) Delete(req *state.DeleteRequest) error { func (c *Cassandra) Delete(ctx context.Context, req *state.DeleteRequest) error {
return c.session.Query(fmt.Sprintf("DELETE FROM %s WHERE key = ?", c.table), req.Key).Exec() return c.session.Query(fmt.Sprintf("DELETE FROM %s WHERE key = ?", c.table), req.Key).WithContext(ctx).Exec()
} }
// Get retrieves state from cassandra with a key. // Get retrieves state from cassandra with a key.
func (c *Cassandra) Get(req *state.GetRequest) (*state.GetResponse, error) { func (c *Cassandra) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) {
session := c.session session := c.session
if req.Options.Consistency == state.Strong { if req.Options.Consistency == state.Strong {
@ -254,7 +255,7 @@ func (c *Cassandra) Get(req *state.GetRequest) (*state.GetResponse, error) {
session = sess session = sess
} }
results, err := session.Query(fmt.Sprintf("SELECT value FROM %s WHERE key = ?", c.table), req.Key).Iter().SliceMap() results, err := session.Query(fmt.Sprintf("SELECT value FROM %s WHERE key = ?", c.table), req.Key).WithContext(ctx).Iter().SliceMap()
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -269,7 +270,7 @@ func (c *Cassandra) Get(req *state.GetRequest) (*state.GetResponse, error) {
} }
// Set saves state into cassandra. // Set saves state into cassandra.
func (c *Cassandra) Set(req *state.SetRequest) error { func (c *Cassandra) Set(ctx context.Context, req *state.SetRequest) error {
var bt []byte var bt []byte
b, ok := req.Value.([]byte) b, ok := req.Value.([]byte)
if ok { if ok {
@ -302,10 +303,10 @@ func (c *Cassandra) Set(req *state.SetRequest) error {
} }
if ttl != nil { if ttl != nil {
return session.Query(fmt.Sprintf("INSERT INTO %s (key, value) VALUES (?, ?) USING TTL ?", c.table), req.Key, bt, *ttl).Exec() return session.Query(fmt.Sprintf("INSERT INTO %s (key, value) VALUES (?, ?) USING TTL ?", c.table), req.Key, bt, *ttl).WithContext(ctx).Exec()
} }
return session.Query(fmt.Sprintf("INSERT INTO %s (key, value) VALUES (?, ?)", c.table), req.Key, bt).Exec() return session.Query(fmt.Sprintf("INSERT INTO %s (key, value) VALUES (?, ?)", c.table), req.Key, bt).WithContext(ctx).Exec()
} }
func (c *Cassandra) createSession(consistency gocql.Consistency) (*gocql.Session, error) { func (c *Cassandra) createSession(consistency gocql.Consistency) (*gocql.Session, error) {

View File

@ -14,6 +14,8 @@ limitations under the License.
package cockroachdb package cockroachdb
import ( import (
"context"
"github.com/dapr/components-contrib/state" "github.com/dapr/components-contrib/state"
"github.com/dapr/kit/logger" "github.com/dapr/kit/logger"
) )
@ -53,18 +55,18 @@ func (c *CockroachDB) Features() []state.Feature {
} }
// Delete removes an entity from the store. // Delete removes an entity from the store.
func (c *CockroachDB) Delete(req *state.DeleteRequest) error { func (c *CockroachDB) Delete(ctx context.Context, req *state.DeleteRequest) error {
return c.dbaccess.Delete(req) return c.dbaccess.Delete(ctx, req)
} }
// Get returns an entity from store. // Get returns an entity from store.
func (c *CockroachDB) Get(req *state.GetRequest) (*state.GetResponse, error) { func (c *CockroachDB) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) {
return c.dbaccess.Get(req) return c.dbaccess.Get(ctx, req)
} }
// Set adds/updates an entity on store. // Set adds/updates an entity on store.
func (c *CockroachDB) Set(req *state.SetRequest) error { func (c *CockroachDB) Set(ctx context.Context, req *state.SetRequest) error {
return c.dbaccess.Set(req) return c.dbaccess.Set(ctx, req)
} }
// Ping checks if database is available. // Ping checks if database is available.
@ -73,29 +75,29 @@ func (c *CockroachDB) Ping() error {
} }
// BulkDelete removes multiple entries from the store. // BulkDelete removes multiple entries from the store.
func (c *CockroachDB) BulkDelete(req []state.DeleteRequest) error { func (c *CockroachDB) BulkDelete(ctx context.Context, req []state.DeleteRequest) error {
return c.dbaccess.BulkDelete(req) return c.dbaccess.BulkDelete(ctx, req)
} }
// BulkGet performs a bulks get operations. // BulkGet performs a bulks get operations.
func (c *CockroachDB) BulkGet(req []state.GetRequest) (bool, []state.BulkGetResponse, error) { func (c *CockroachDB) BulkGet(ctx context.Context, req []state.GetRequest) (bool, []state.BulkGetResponse, error) {
// TODO: replace with ExecuteMulti for performance. // TODO: replace with ExecuteMulti for performance.
return false, nil, nil return false, nil, nil
} }
// BulkSet adds/updates multiple entities on store. // BulkSet adds/updates multiple entities on store.
func (c *CockroachDB) BulkSet(req []state.SetRequest) error { func (c *CockroachDB) BulkSet(ctx context.Context, req []state.SetRequest) error {
return c.dbaccess.BulkSet(req) return c.dbaccess.BulkSet(ctx, req)
} }
// Multi handles multiple transactions. Implements TransactionalStore. // Multi handles multiple transactions. Implements TransactionalStore.
func (c *CockroachDB) Multi(request *state.TransactionalStateRequest) error { func (c *CockroachDB) Multi(ctx context.Context, request *state.TransactionalStateRequest) error {
return c.dbaccess.ExecuteMulti(request) return c.dbaccess.ExecuteMulti(ctx, request)
} }
// Query executes a query against store. // Query executes a query against store.
func (c *CockroachDB) Query(req *state.QueryRequest) (*state.QueryResponse, error) { func (c *CockroachDB) Query(ctx context.Context, req *state.QueryRequest) (*state.QueryResponse, error) {
return c.dbaccess.Query(req) return c.dbaccess.Query(ctx, req)
} }
// Close implements io.Closer. // Close implements io.Closer.

View File

@ -14,6 +14,7 @@ limitations under the License.
package cockroachdb package cockroachdb
import ( import (
"context"
"database/sql" "database/sql"
"encoding/base64" "encoding/base64"
"encoding/json" "encoding/json"
@ -96,12 +97,12 @@ func (p *cockroachDBAccess) Init(metadata state.Metadata) error {
} }
// Set makes an insert or update to the database. // Set makes an insert or update to the database.
func (p *cockroachDBAccess) Set(req *state.SetRequest) error { func (p *cockroachDBAccess) Set(ctx context.Context, req *state.SetRequest) error {
return state.SetWithOptions(p.setValue, req) return state.SetWithOptions(ctx, p.setValue, req)
} }
// setValue is an internal implementation of set to enable passing the logic to state.SetWithRetries as a func. // setValue is an internal implementation of set to enable passing the logic to state.SetWithRetries as a func.
func (p *cockroachDBAccess) setValue(req *state.SetRequest) error { func (p *cockroachDBAccess) setValue(ctx context.Context, req *state.SetRequest) error {
p.logger.Debug("Setting state value in CockroachDB") p.logger.Debug("Setting state value in CockroachDB")
value, isBinary, err := validateAndReturnValue(req) value, isBinary, err := validateAndReturnValue(req)
@ -114,7 +115,7 @@ func (p *cockroachDBAccess) setValue(req *state.SetRequest) error {
// Sprintf is required for table name because sql.DB does not substitute parameters for table names. // Sprintf is required for table name because sql.DB does not substitute parameters for table names.
// Other parameters use sql.DB parameter substitution. // Other parameters use sql.DB parameter substitution.
if req.ETag == nil { if req.ETag == nil {
result, err = p.db.Exec(fmt.Sprintf( result, err = p.db.ExecContext(ctx, fmt.Sprintf(
`INSERT INTO %s (key, value, isbinary, etag) VALUES ($1, $2, $3, 1) `INSERT INTO %s (key, value, isbinary, etag) VALUES ($1, $2, $3, 1)
ON CONFLICT (key) DO UPDATE SET value = $2, isbinary = $3, updatedate = NOW(), etag = EXCLUDED.etag + 1;`, ON CONFLICT (key) DO UPDATE SET value = $2, isbinary = $3, updatedate = NOW(), etag = EXCLUDED.etag + 1;`,
tableName), req.Key, value, isBinary) tableName), req.Key, value, isBinary)
@ -127,7 +128,7 @@ func (p *cockroachDBAccess) setValue(req *state.SetRequest) error {
etag := uint32(etag64) etag := uint32(etag64)
// When an etag is provided do an update - no insert. // When an etag is provided do an update - no insert.
result, err = p.db.Exec(fmt.Sprintf( result, err = p.db.ExecContext(ctx, fmt.Sprintf(
`UPDATE %s SET value = $1, isbinary = $2, updatedate = NOW(), etag = etag + 1 `UPDATE %s SET value = $1, isbinary = $2, updatedate = NOW(), etag = etag + 1
WHERE key = $3 AND etag = $4;`, WHERE key = $3 AND etag = $4;`,
tableName), value, isBinary, req.Key, etag) tableName), value, isBinary, req.Key, etag)
@ -149,7 +150,7 @@ func (p *cockroachDBAccess) setValue(req *state.SetRequest) error {
return nil return nil
} }
func (p *cockroachDBAccess) BulkSet(req []state.SetRequest) error { func (p *cockroachDBAccess) BulkSet(ctx context.Context, req []state.SetRequest) error {
p.logger.Debug("Executing BulkSet request") p.logger.Debug("Executing BulkSet request")
tx, err := p.db.Begin() tx, err := p.db.Begin()
if err != nil { if err != nil {
@ -159,7 +160,7 @@ func (p *cockroachDBAccess) BulkSet(req []state.SetRequest) error {
if len(req) > 0 { if len(req) > 0 {
for _, s := range req { for _, s := range req {
sa := s // Fix for gosec G601: Implicit memory aliasing in for loop. sa := s // Fix for gosec G601: Implicit memory aliasing in for loop.
err = p.Set(&sa) err = p.Set(ctx, &sa)
if err != nil { if err != nil {
tx.Rollback() tx.Rollback()
@ -174,7 +175,7 @@ func (p *cockroachDBAccess) BulkSet(req []state.SetRequest) error {
} }
// Get returns data from the database. If data does not exist for the key an empty state.GetResponse will be returned. // Get returns data from the database. If data does not exist for the key an empty state.GetResponse will be returned.
func (p *cockroachDBAccess) Get(req *state.GetRequest) (*state.GetResponse, error) { func (p *cockroachDBAccess) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) {
p.logger.Debug("Getting state value from CockroachDB") p.logger.Debug("Getting state value from CockroachDB")
if req.Key == "" { if req.Key == "" {
return nil, fmt.Errorf("missing key in get operation") return nil, fmt.Errorf("missing key in get operation")
@ -183,7 +184,7 @@ func (p *cockroachDBAccess) Get(req *state.GetRequest) (*state.GetResponse, erro
var value string var value string
var isBinary bool var isBinary bool
var etag int var etag int
err := p.db.QueryRow(fmt.Sprintf("SELECT value, isbinary, etag FROM %s WHERE key = $1", tableName), req.Key).Scan(&value, &isBinary, &etag) err := p.db.QueryRowContext(ctx, fmt.Sprintf("SELECT value, isbinary, etag FROM %s WHERE key = $1", tableName), req.Key).Scan(&value, &isBinary, &etag)
if err != nil { if err != nil {
// If no rows exist, return an empty response, otherwise return the error. // If no rows exist, return an empty response, otherwise return the error.
if errors.Is(err, sql.ErrNoRows) { if errors.Is(err, sql.ErrNoRows) {
@ -222,12 +223,12 @@ func (p *cockroachDBAccess) Get(req *state.GetRequest) (*state.GetResponse, erro
} }
// Delete removes an item from the state store. // Delete removes an item from the state store.
func (p *cockroachDBAccess) Delete(req *state.DeleteRequest) error { func (p *cockroachDBAccess) Delete(ctx context.Context, req *state.DeleteRequest) error {
return state.DeleteWithOptions(p.deleteValue, req) return state.DeleteWithOptions(ctx, p.deleteValue, req)
} }
// deleteValue is an internal implementation of delete to enable passing the logic to state.DeleteWithRetries as a func. // deleteValue is an internal implementation of delete to enable passing the logic to state.DeleteWithRetries as a func.
func (p *cockroachDBAccess) deleteValue(req *state.DeleteRequest) error { func (p *cockroachDBAccess) deleteValue(ctx context.Context, req *state.DeleteRequest) error {
p.logger.Debug("Deleting state value from CockroachDB") p.logger.Debug("Deleting state value from CockroachDB")
if req.Key == "" { if req.Key == "" {
return fmt.Errorf("missing key in delete operation") return fmt.Errorf("missing key in delete operation")
@ -237,7 +238,7 @@ func (p *cockroachDBAccess) deleteValue(req *state.DeleteRequest) error {
var err error var err error
if req.ETag == nil { if req.ETag == nil {
result, err = p.db.Exec("DELETE FROM state WHERE key = $1", req.Key) result, err = p.db.ExecContext(ctx, "DELETE FROM state WHERE key = $1", req.Key)
} else { } else {
var etag64 uint64 var etag64 uint64
etag64, err = strconv.ParseUint(*req.ETag, 10, 32) etag64, err = strconv.ParseUint(*req.ETag, 10, 32)
@ -246,7 +247,7 @@ func (p *cockroachDBAccess) deleteValue(req *state.DeleteRequest) error {
} }
etag := uint32(etag64) etag := uint32(etag64)
result, err = p.db.Exec("DELETE FROM state WHERE key = $1 and etag = $2", req.Key, etag) result, err = p.db.ExecContext(ctx, "DELETE FROM state WHERE key = $1 and etag = $2", req.Key, etag)
} }
if err != nil { if err != nil {
@ -265,7 +266,7 @@ func (p *cockroachDBAccess) deleteValue(req *state.DeleteRequest) error {
return nil return nil
} }
func (p *cockroachDBAccess) BulkDelete(req []state.DeleteRequest) error { func (p *cockroachDBAccess) BulkDelete(ctx context.Context, req []state.DeleteRequest) error {
p.logger.Debug("Executing BulkDelete request") p.logger.Debug("Executing BulkDelete request")
tx, err := p.db.Begin() tx, err := p.db.Begin()
if err != nil { if err != nil {
@ -275,7 +276,7 @@ func (p *cockroachDBAccess) BulkDelete(req []state.DeleteRequest) error {
if len(req) > 0 { if len(req) > 0 {
for _, d := range req { for _, d := range req {
da := d // Fix for gosec G601: Implicit memory aliasing in for loop. da := d // Fix for gosec G601: Implicit memory aliasing in for loop.
err = p.Delete(&da) err = p.Delete(ctx, &da)
if err != nil { if err != nil {
tx.Rollback() tx.Rollback()
@ -289,7 +290,7 @@ func (p *cockroachDBAccess) BulkDelete(req []state.DeleteRequest) error {
return err return err
} }
func (p *cockroachDBAccess) ExecuteMulti(request *state.TransactionalStateRequest) error { func (p *cockroachDBAccess) ExecuteMulti(ctx context.Context, request *state.TransactionalStateRequest) error {
p.logger.Debug("Executing PostgreSQL transaction") p.logger.Debug("Executing PostgreSQL transaction")
tx, err := p.db.Begin() tx, err := p.db.Begin()
@ -308,7 +309,7 @@ func (p *cockroachDBAccess) ExecuteMulti(request *state.TransactionalStateReques
return err return err
} }
err = p.Set(&setReq) err = p.Set(ctx, &setReq)
if err != nil { if err != nil {
tx.Rollback() tx.Rollback()
return err return err
@ -323,7 +324,7 @@ func (p *cockroachDBAccess) ExecuteMulti(request *state.TransactionalStateReques
return err return err
} }
err = p.Delete(&delReq) err = p.Delete(ctx, &delReq)
if err != nil { if err != nil {
tx.Rollback() tx.Rollback()
return err return err
@ -341,7 +342,7 @@ func (p *cockroachDBAccess) ExecuteMulti(request *state.TransactionalStateReques
} }
// Query executes a query against store. // Query executes a query against store.
func (p *cockroachDBAccess) Query(req *state.QueryRequest) (*state.QueryResponse, error) { func (p *cockroachDBAccess) Query(ctx context.Context, req *state.QueryRequest) (*state.QueryResponse, error) {
p.logger.Debug("Getting query value from CockroachDB") p.logger.Debug("Getting query value from CockroachDB")
stateQuery := &Query{ stateQuery := &Query{
@ -361,7 +362,7 @@ func (p *cockroachDBAccess) Query(req *state.QueryRequest) (*state.QueryResponse
p.logger.Debug("Query: " + stateQuery.query) p.logger.Debug("Query: " + stateQuery.query)
data, token, err := stateQuery.execute(p.logger, p.db) data, token, err := stateQuery.execute(ctx, p.logger, p.db)
if err != nil { if err != nil {
return &state.QueryResponse{ return &state.QueryResponse{
Results: []state.QueryItem{}, Results: []state.QueryItem{},

View File

@ -14,6 +14,7 @@ limitations under the License.
package cockroachdb package cockroachdb
import ( import (
"context"
"database/sql" "database/sql"
"testing" "testing"
@ -109,7 +110,7 @@ func TestMultiWithNoRequests(t *testing.T) {
var operations []state.TransactionalStateOperation var operations []state.TransactionalStateOperation
// Act // Act
err := m.roachDba.ExecuteMulti(&state.TransactionalStateRequest{ err := m.roachDba.ExecuteMulti(context.TODO(), &state.TransactionalStateRequest{
Operations: operations, Operations: operations,
}) })
@ -133,7 +134,7 @@ func TestInvalidMultiInvalidAction(t *testing.T) {
}) })
// Act // Act
err := m.roachDba.ExecuteMulti(&state.TransactionalStateRequest{ err := m.roachDba.ExecuteMulti(context.TODO(), &state.TransactionalStateRequest{
Operations: operations, Operations: operations,
}) })
@ -158,7 +159,7 @@ func TestValidSetRequest(t *testing.T) {
}) })
// Act // Act
err := m.roachDba.ExecuteMulti(&state.TransactionalStateRequest{ err := m.roachDba.ExecuteMulti(context.TODO(), &state.TransactionalStateRequest{
Operations: operations, Operations: operations,
}) })
@ -182,7 +183,7 @@ func TestInvalidMultiSetRequest(t *testing.T) {
}) })
// Act // Act
err := m.roachDba.ExecuteMulti(&state.TransactionalStateRequest{ err := m.roachDba.ExecuteMulti(context.TODO(), &state.TransactionalStateRequest{
Operations: operations, Operations: operations,
}) })
@ -206,7 +207,7 @@ func TestInvalidMultiSetRequestNoKey(t *testing.T) {
}) })
// Act // Act
err := m.roachDba.ExecuteMulti(&state.TransactionalStateRequest{ err := m.roachDba.ExecuteMulti(context.TODO(), &state.TransactionalStateRequest{
Operations: operations, Operations: operations,
}) })
@ -231,7 +232,7 @@ func TestValidMultiDeleteRequest(t *testing.T) {
}) })
// Act // Act
err := m.roachDba.ExecuteMulti(&state.TransactionalStateRequest{ err := m.roachDba.ExecuteMulti(context.TODO(), &state.TransactionalStateRequest{
Operations: operations, Operations: operations,
}) })
@ -255,7 +256,7 @@ func TestInvalidMultiDeleteRequest(t *testing.T) {
}) })
// Act // Act
err := m.roachDba.ExecuteMulti(&state.TransactionalStateRequest{ err := m.roachDba.ExecuteMulti(context.TODO(), &state.TransactionalStateRequest{
Operations: operations, Operations: operations,
}) })
@ -279,7 +280,7 @@ func TestInvalidMultiDeleteRequestNoKey(t *testing.T) {
}) })
// Act // Act
err := m.roachDba.ExecuteMulti(&state.TransactionalStateRequest{ err := m.roachDba.ExecuteMulti(context.TODO(), &state.TransactionalStateRequest{
Operations: operations, Operations: operations,
}) })
@ -311,7 +312,7 @@ func TestMultiOperationOrder(t *testing.T) {
) )
// Act // Act
err := m.roachDba.ExecuteMulti(&state.TransactionalStateRequest{ err := m.roachDba.ExecuteMulti(context.TODO(), &state.TransactionalStateRequest{
Operations: operations, Operations: operations,
}) })
@ -334,7 +335,7 @@ func TestInvalidBulkSetNoKey(t *testing.T) {
}) })
// Act // Act
err := m.roachDba.BulkSet(sets) err := m.roachDba.BulkSet(context.TODO(), sets)
// Assert // Assert
assert.NotNil(t, err) assert.NotNil(t, err)
@ -356,7 +357,7 @@ func TestInvalidBulkSetEmptyValue(t *testing.T) {
}) })
// Act // Act
err := m.roachDba.BulkSet(sets) err := m.roachDba.BulkSet(context.TODO(), sets)
// Assert // Assert
assert.NotNil(t, err) assert.NotNil(t, err)
@ -379,7 +380,7 @@ func TestValidBulkSet(t *testing.T) {
}) })
// Act // Act
err := m.roachDba.BulkSet(sets) err := m.roachDba.BulkSet(context.TODO(), sets)
// Assert // Assert
assert.Nil(t, err) assert.Nil(t, err)
@ -400,7 +401,7 @@ func TestInvalidBulkDeleteNoKey(t *testing.T) {
}) })
// Act // Act
err := m.roachDba.BulkDelete(deletes) err := m.roachDba.BulkDelete(context.TODO(), deletes)
// Assert // Assert
assert.NotNil(t, err) assert.NotNil(t, err)
@ -422,7 +423,7 @@ func TestValidBulkDelete(t *testing.T) {
}) })
// Act // Act
err := m.roachDba.BulkDelete(deletes) err := m.roachDba.BulkDelete(context.TODO(), deletes)
// Assert // Assert
assert.Nil(t, err) assert.Nil(t, err)

View File

@ -14,6 +14,7 @@ limitations under the License.
package cockroachdb package cockroachdb
import ( import (
"context"
"database/sql" "database/sql"
"encoding/json" "encoding/json"
"fmt" "fmt"
@ -211,7 +212,7 @@ func deleteItemThatDoesNotExist(t *testing.T, pgs *CockroachDB) {
Consistency: "", Consistency: "",
}, },
} }
err := pgs.Delete(deleteReq) err := pgs.Delete(context.TODO(), deleteReq)
assert.Nil(t, err) assert.Nil(t, err)
} }
@ -239,7 +240,7 @@ func multiWithSetOnly(t *testing.T, pgs *CockroachDB) {
}) })
} }
err := pgs.Multi(&state.TransactionalStateRequest{ err := pgs.Multi(context.TODO(), &state.TransactionalStateRequest{
Operations: operations, Operations: operations,
Metadata: nil, Metadata: nil,
}) })
@ -280,7 +281,7 @@ func multiWithDeleteOnly(t *testing.T, pgs *CockroachDB) {
}) })
} }
err := pgs.Multi(&state.TransactionalStateRequest{ err := pgs.Multi(context.TODO(), &state.TransactionalStateRequest{
Operations: operations, Operations: operations,
Metadata: nil, Metadata: nil,
}) })
@ -341,7 +342,7 @@ func multiWithDeleteAndSet(t *testing.T, pgs *CockroachDB) {
}) })
} }
err := pgs.Multi(&state.TransactionalStateRequest{ err := pgs.Multi(context.TODO(), &state.TransactionalStateRequest{
Operations: operations, Operations: operations,
Metadata: nil, Metadata: nil,
}) })
@ -376,7 +377,7 @@ func deleteWithInvalidEtagFails(t *testing.T, pgs *CockroachDB) {
Consistency: "", Consistency: "",
}, },
} }
err := pgs.Delete(deleteReq) err := pgs.Delete(context.TODO(), deleteReq)
assert.NotNil(t, err) assert.NotNil(t, err)
} }
@ -392,7 +393,7 @@ func deleteWithNoKeyFails(t *testing.T, pgs *CockroachDB) {
Consistency: "", Consistency: "",
}, },
} }
err := pgs.Delete(deleteReq) err := pgs.Delete(context.TODO(), deleteReq)
assert.NotNil(t, err) assert.NotNil(t, err)
} }
@ -415,7 +416,7 @@ func newItemWithEtagFails(t *testing.T, pgs *CockroachDB) {
ContentType: nil, ContentType: nil,
} }
err := pgs.Set(setReq) err := pgs.Set(context.TODO(), setReq)
assert.NotNil(t, err) assert.NotNil(t, err)
} }
@ -449,7 +450,7 @@ func updateWithOldEtagFails(t *testing.T, pgs *CockroachDB) {
}, },
ContentType: nil, ContentType: nil,
} }
err := pgs.Set(setReq) err := pgs.Set(context.TODO(), setReq)
assert.NotNil(t, err) assert.NotNil(t, err)
} }
@ -501,7 +502,7 @@ func getItemWithNoKey(t *testing.T, pgs *CockroachDB) {
}, },
} }
response, getErr := pgs.Get(getReq) response, getErr := pgs.Get(context.TODO(), getReq)
assert.NotNil(t, getErr) assert.NotNil(t, getErr)
assert.Nil(t, response) assert.Nil(t, response)
} }
@ -544,7 +545,7 @@ func setItemWithNoKey(t *testing.T, pgs *CockroachDB) {
ContentType: nil, ContentType: nil,
} }
err := pgs.Set(setReq) err := pgs.Set(context.TODO(), setReq)
assert.NotNil(t, err) assert.NotNil(t, err)
} }
@ -563,7 +564,7 @@ func testBulkSetAndBulkDelete(t *testing.T, pgs *CockroachDB) {
}, },
} }
err := pgs.BulkSet(setReq) err := pgs.BulkSet(context.TODO(), setReq)
assert.Nil(t, err) assert.Nil(t, err)
assert.True(t, storeItemExists(t, setReq[0].Key)) assert.True(t, storeItemExists(t, setReq[0].Key))
assert.True(t, storeItemExists(t, setReq[1].Key)) assert.True(t, storeItemExists(t, setReq[1].Key))
@ -577,7 +578,7 @@ func testBulkSetAndBulkDelete(t *testing.T, pgs *CockroachDB) {
}, },
} }
err = pgs.BulkDelete(deleteReq) err = pgs.BulkDelete(context.TODO(), deleteReq)
assert.Nil(t, err) assert.Nil(t, err)
assert.False(t, storeItemExists(t, setReq[0].Key)) assert.False(t, storeItemExists(t, setReq[0].Key))
assert.False(t, storeItemExists(t, setReq[1].Key)) assert.False(t, storeItemExists(t, setReq[1].Key))
@ -644,7 +645,7 @@ func setItem(t *testing.T, pgs *CockroachDB, key string, value interface{}, etag
ContentType: nil, ContentType: nil,
} }
err := pgs.Set(setReq) err := pgs.Set(context.TODO(), setReq)
assert.Nil(t, err) assert.Nil(t, err)
itemExists := storeItemExists(t, key) itemExists := storeItemExists(t, key)
assert.True(t, itemExists) assert.True(t, itemExists)
@ -661,7 +662,7 @@ func getItem(t *testing.T, pgs *CockroachDB, key string) (*state.GetResponse, *f
Metadata: map[string]string{}, Metadata: map[string]string{},
} }
response, getErr := pgs.Get(getReq) response, getErr := pgs.Get(context.TODO(), getReq)
assert.Nil(t, getErr) assert.Nil(t, getErr)
assert.NotNil(t, response) assert.NotNil(t, response)
outputObject := &fakeItem{ outputObject := &fakeItem{
@ -685,7 +686,7 @@ func deleteItem(t *testing.T, pgs *CockroachDB, key string, etag *string) {
Metadata: map[string]string{}, Metadata: map[string]string{},
} }
deleteErr := pgs.Delete(deleteReq) deleteErr := pgs.Delete(context.TODO(), deleteReq)
assert.Nil(t, deleteErr) assert.Nil(t, deleteErr)
assert.False(t, storeItemExists(t, key)) assert.False(t, storeItemExists(t, key))
} }

View File

@ -14,6 +14,7 @@ limitations under the License.
package cockroachdb package cockroachdb
import ( import (
"context"
"database/sql" "database/sql"
"fmt" "fmt"
"strconv" "strconv"
@ -133,8 +134,8 @@ func (q *Query) Finalize(filters string, storeQuery *query.Query) error {
return nil return nil
} }
func (q *Query) execute(logger logger.Logger, db *sql.DB) ([]state.QueryItem, string, error) { func (q *Query) execute(ctx context.Context, logger logger.Logger, db *sql.DB) ([]state.QueryItem, string, error) {
rows, err := db.Query(q.query, q.params...) rows, err := db.QueryContext(ctx, q.query, q.params...)
if err != nil { if err != nil {
return nil, "", fmt.Errorf("query executes '%s' failed: %w", q.query, err) return nil, "", fmt.Errorf("query executes '%s' failed: %w", q.query, err)
} }

View File

@ -14,6 +14,7 @@ limitations under the License.
package cockroachdb package cockroachdb
import ( import (
"context"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@ -42,37 +43,37 @@ func (m *fakeDBaccess) Init(metadata state.Metadata) error {
return nil return nil
} }
func (m *fakeDBaccess) Set(req *state.SetRequest) error { func (m *fakeDBaccess) Set(ctx context.Context, req *state.SetRequest) error {
m.setExecuted = true m.setExecuted = true
return nil return nil
} }
func (m *fakeDBaccess) Get(req *state.GetRequest) (*state.GetResponse, error) { func (m *fakeDBaccess) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) {
m.getExecuted = true m.getExecuted = true
return nil, nil return nil, nil
} }
func (m *fakeDBaccess) Delete(req *state.DeleteRequest) error { func (m *fakeDBaccess) Delete(ctx context.Context, req *state.DeleteRequest) error {
m.deleteExecuted = true m.deleteExecuted = true
return nil return nil
} }
func (m *fakeDBaccess) BulkSet(req []state.SetRequest) error { func (m *fakeDBaccess) BulkSet(ctx context.Context, req []state.SetRequest) error {
return nil return nil
} }
func (m *fakeDBaccess) BulkDelete(req []state.DeleteRequest) error { func (m *fakeDBaccess) BulkDelete(ctx context.Context, req []state.DeleteRequest) error {
return nil return nil
} }
func (m *fakeDBaccess) ExecuteMulti(req *state.TransactionalStateRequest) error { func (m *fakeDBaccess) ExecuteMulti(ctx context.Context, req *state.TransactionalStateRequest) error {
return nil return nil
} }
func (m *fakeDBaccess) Query(req *state.QueryRequest) (*state.QueryResponse, error) { func (m *fakeDBaccess) Query(ctx context.Context, req *state.QueryRequest) (*state.QueryResponse, error) {
return nil, nil return nil, nil
} }

View File

@ -13,18 +13,22 @@ limitations under the License.
package cockroachdb package cockroachdb
import "github.com/dapr/components-contrib/state" import (
"context"
"github.com/dapr/components-contrib/state"
)
// dbAccess is a private interface which enables unit testing of CockroachDB. // dbAccess is a private interface which enables unit testing of CockroachDB.
type dbAccess interface { type dbAccess interface {
Init(metadata state.Metadata) error Init(metadata state.Metadata) error
Set(req *state.SetRequest) error Set(ctx context.Context, req *state.SetRequest) error
BulkSet(req []state.SetRequest) error BulkSet(ctx context.Context, req []state.SetRequest) error
Get(req *state.GetRequest) (*state.GetResponse, error) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error)
Delete(req *state.DeleteRequest) error Delete(ctx context.Context, req *state.DeleteRequest) error
BulkDelete(req []state.DeleteRequest) error BulkDelete(ctx context.Context, req []state.DeleteRequest) error
ExecuteMulti(req *state.TransactionalStateRequest) error ExecuteMulti(ctx context.Context, req *state.TransactionalStateRequest) error
Query(req *state.QueryRequest) (*state.QueryResponse, error) Query(ctx context.Context, req *state.QueryRequest) (*state.QueryResponse, error)
Ping() error Ping() error
Close() error Close() error
} }

View File

@ -14,6 +14,7 @@ limitations under the License.
package couchbase package couchbase
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"strconv" "strconv"
@ -144,7 +145,7 @@ func (cbs *Couchbase) Features() []state.Feature {
} }
// Set stores value for a key to couchbase. It honors ETag (for concurrency) and consistency settings. // Set stores value for a key to couchbase. It honors ETag (for concurrency) and consistency settings.
func (cbs *Couchbase) Set(req *state.SetRequest) error { func (cbs *Couchbase) Set(ctx context.Context, req *state.SetRequest) error {
err := state.CheckRequestOptions(req.Options) err := state.CheckRequestOptions(req.Options)
if err != nil { if err != nil {
return err return err
@ -188,7 +189,7 @@ func (cbs *Couchbase) Set(req *state.SetRequest) error {
} }
// Get retrieves state from couchbase with a key. // Get retrieves state from couchbase with a key.
func (cbs *Couchbase) Get(req *state.GetRequest) (*state.GetResponse, error) { func (cbs *Couchbase) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) {
var data interface{} var data interface{}
cas, err := cbs.bucket.Get(req.Key, &data) cas, err := cbs.bucket.Get(req.Key, &data)
if err != nil { if err != nil {
@ -206,7 +207,7 @@ func (cbs *Couchbase) Get(req *state.GetRequest) (*state.GetResponse, error) {
} }
// Delete performs a delete operation. // Delete performs a delete operation.
func (cbs *Couchbase) Delete(req *state.DeleteRequest) error { func (cbs *Couchbase) Delete(ctx context.Context, req *state.DeleteRequest) error {
err := state.CheckRequestOptions(req.Options) err := state.CheckRequestOptions(req.Options)
if err != nil { if err != nil {
return err return err

View File

@ -93,12 +93,12 @@ func (f *Firestore) Features() []state.Feature {
} }
// Get retrieves state from Firestore with a key (Always strong consistency). // Get retrieves state from Firestore with a key (Always strong consistency).
func (f *Firestore) Get(req *state.GetRequest) (*state.GetResponse, error) { func (f *Firestore) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) {
key := req.Key key := req.Key
entityKey := datastore.NameKey(f.entityKind, key, nil) entityKey := datastore.NameKey(f.entityKind, key, nil)
var entity StateEntity var entity StateEntity
err := f.client.Get(context.Background(), entityKey, &entity) err := f.client.Get(ctx, entityKey, &entity)
if err != nil && !errors.Is(err, datastore.ErrNoSuchEntity) { if err != nil && !errors.Is(err, datastore.ErrNoSuchEntity) {
return nil, err return nil, err
@ -111,7 +111,7 @@ func (f *Firestore) Get(req *state.GetRequest) (*state.GetResponse, error) {
}, nil }, nil
} }
func (f *Firestore) setValue(req *state.SetRequest) error { func (f *Firestore) setValue(ctx context.Context, req *state.SetRequest) error {
err := state.CheckRequestOptions(req.Options) err := state.CheckRequestOptions(req.Options)
if err != nil { if err != nil {
return err return err
@ -128,7 +128,6 @@ func (f *Firestore) setValue(req *state.SetRequest) error {
entity := &StateEntity{ entity := &StateEntity{
Value: v, Value: v,
} }
ctx := context.Background()
key := datastore.NameKey(f.entityKind, req.Key, nil) key := datastore.NameKey(f.entityKind, req.Key, nil)
_, err = f.client.Put(ctx, key, entity) _, err = f.client.Put(ctx, key, entity)
@ -141,12 +140,11 @@ func (f *Firestore) setValue(req *state.SetRequest) error {
} }
// Set saves state into Firestore with retry. // Set saves state into Firestore with retry.
func (f *Firestore) Set(req *state.SetRequest) error { func (f *Firestore) Set(ctx context.Context, req *state.SetRequest) error {
return state.SetWithOptions(f.setValue, req) return state.SetWithOptions(ctx, f.setValue, req)
} }
func (f *Firestore) deleteValue(req *state.DeleteRequest) error { func (f *Firestore) deleteValue(ctx context.Context, req *state.DeleteRequest) error {
ctx := context.Background()
key := datastore.NameKey(f.entityKind, req.Key, nil) key := datastore.NameKey(f.entityKind, req.Key, nil)
err := f.client.Delete(ctx, key) err := f.client.Delete(ctx, key)
@ -158,8 +156,8 @@ func (f *Firestore) deleteValue(req *state.DeleteRequest) error {
} }
// Delete performs a delete operation. // Delete performs a delete operation.
func (f *Firestore) Delete(req *state.DeleteRequest) error { func (f *Firestore) Delete(ctx context.Context, req *state.DeleteRequest) error {
return state.DeleteWithOptions(f.deleteValue, req) return state.DeleteWithOptions(ctx, f.deleteValue, req)
} }
func getFirestoreMetadata(metadata state.Metadata) (*firestoreMetadata, error) { func getFirestoreMetadata(metadata state.Metadata) (*firestoreMetadata, error) {

View File

@ -14,6 +14,7 @@ limitations under the License.
package consul package consul
import ( import (
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
@ -102,11 +103,12 @@ func metadataToConfig(connInfo map[string]string) (*consulConfig, error) {
} }
// Get retrieves a Consul KV item. // Get retrieves a Consul KV item.
func (c *Consul) Get(req *state.GetRequest) (*state.GetResponse, error) { func (c *Consul) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) {
queryOpts := &api.QueryOptions{} queryOpts := &api.QueryOptions{}
if req.Options.Consistency == state.Strong { if req.Options.Consistency == state.Strong {
queryOpts.RequireConsistent = true queryOpts.RequireConsistent = true
} }
queryOpts = queryOpts.WithContext(ctx)
resp, queryMeta, err := c.client.KV().Get(fmt.Sprintf("%s/%s", c.keyPrefixPath, req.Key), queryOpts) resp, queryMeta, err := c.client.KV().Get(fmt.Sprintf("%s/%s", c.keyPrefixPath, req.Key), queryOpts)
if err != nil { if err != nil {
@ -124,7 +126,7 @@ func (c *Consul) Get(req *state.GetRequest) (*state.GetResponse, error) {
} }
// Set saves a Consul KV item. // Set saves a Consul KV item.
func (c *Consul) Set(req *state.SetRequest) error { func (c *Consul) Set(ctx context.Context, req *state.SetRequest) error {
var reqValByte []byte var reqValByte []byte
b, ok := req.Value.([]byte) b, ok := req.Value.([]byte)
if ok { if ok {
@ -135,10 +137,12 @@ func (c *Consul) Set(req *state.SetRequest) error {
keyWithPath := fmt.Sprintf("%s/%s", c.keyPrefixPath, req.Key) keyWithPath := fmt.Sprintf("%s/%s", c.keyPrefixPath, req.Key)
writeOptions := new(api.WriteOptions)
writeOptions = writeOptions.WithContext(ctx)
_, err := c.client.KV().Put(&api.KVPair{ _, err := c.client.KV().Put(&api.KVPair{
Key: keyWithPath, Key: keyWithPath,
Value: reqValByte, Value: reqValByte,
}, nil) }, writeOptions)
if err != nil { if err != nil {
return fmt.Errorf("couldn't set key %s: %s", keyWithPath, err) return fmt.Errorf("couldn't set key %s: %s", keyWithPath, err)
} }
@ -147,9 +151,11 @@ func (c *Consul) Set(req *state.SetRequest) error {
} }
// Delete performes a Consul KV delete operation. // Delete performes a Consul KV delete operation.
func (c *Consul) Delete(req *state.DeleteRequest) error { func (c *Consul) Delete(ctx context.Context, req *state.DeleteRequest) error {
keyWithPath := fmt.Sprintf("%s/%s", c.keyPrefixPath, req.Key) keyWithPath := fmt.Sprintf("%s/%s", c.keyPrefixPath, req.Key)
_, err := c.client.KV().Delete(keyWithPath, nil) writeOptions := new(api.WriteOptions)
writeOptions = writeOptions.WithContext(ctx)
_, err := c.client.KV().Delete(keyWithPath, writeOptions)
if err != nil { if err != nil {
return fmt.Errorf("couldn't delete key %s: %s", keyWithPath, err) return fmt.Errorf("couldn't delete key %s: %s", keyWithPath, err)
} }

View File

@ -14,6 +14,7 @@ limitations under the License.
package hazelcast package hazelcast
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"strings" "strings"
@ -91,7 +92,7 @@ func (store *Hazelcast) Features() []state.Feature {
} }
// Set stores value for a key to Hazelcast. // Set stores value for a key to Hazelcast.
func (store *Hazelcast) Set(req *state.SetRequest) error { func (store *Hazelcast) Set(ctx context.Context, req *state.SetRequest) error {
err := state.CheckRequestOptions(req) err := state.CheckRequestOptions(req)
if err != nil { if err != nil {
return err return err
@ -117,7 +118,7 @@ func (store *Hazelcast) Set(req *state.SetRequest) error {
} }
// Get retrieves state from Hazelcast with a key. // Get retrieves state from Hazelcast with a key.
func (store *Hazelcast) Get(req *state.GetRequest) (*state.GetResponse, error) { func (store *Hazelcast) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) {
resp, err := store.hzMap.Get(req.Key) resp, err := store.hzMap.Get(req.Key)
if err != nil { if err != nil {
return nil, fmt.Errorf("hazelcast error: failed to get value for %s: %s", req.Key, err) return nil, fmt.Errorf("hazelcast error: failed to get value for %s: %s", req.Key, err)
@ -138,7 +139,7 @@ func (store *Hazelcast) Get(req *state.GetRequest) (*state.GetResponse, error) {
} }
// Delete performs a delete operation. // Delete performs a delete operation.
func (store *Hazelcast) Delete(req *state.DeleteRequest) error { func (store *Hazelcast) Delete(ctx context.Context, req *state.DeleteRequest) error {
err := state.CheckRequestOptions(req.Options) err := state.CheckRequestOptions(req.Options)
if err != nil { if err != nil {
return err return err

View File

@ -73,7 +73,7 @@ func (store *inMemoryStore) Features() []state.Feature {
return []state.Feature{state.FeatureETag, state.FeatureTransactional} return []state.Feature{state.FeatureETag, state.FeatureTransactional}
} }
func (store *inMemoryStore) Delete(req *state.DeleteRequest) error { func (store *inMemoryStore) Delete(ctx context.Context, req *state.DeleteRequest) error {
// step1: validate parameters // step1: validate parameters
if err := store.doDeleteValidateParameters(req); err != nil { if err := store.doDeleteValidateParameters(req); err != nil {
return err return err
@ -90,7 +90,7 @@ func (store *inMemoryStore) Delete(req *state.DeleteRequest) error {
// step3: do really delete // step3: do really delete
// this operation won't fail // this operation won't fail
store.doDelete(req.Key) store.doDelete(ctx, req.Key)
return nil return nil
} }
@ -117,11 +117,11 @@ func (store *inMemoryStore) doValidateEtag(key string, etag *string, concurrency
return nil return nil
} }
func (store *inMemoryStore) doDelete(key string) { func (store *inMemoryStore) doDelete(ctx context.Context, key string) {
delete(store.items, key) delete(store.items, key)
} }
func (store *inMemoryStore) BulkDelete(req []state.DeleteRequest) error { func (store *inMemoryStore) BulkDelete(ctx context.Context, req []state.DeleteRequest) error {
if len(req) == 0 { if len(req) == 0 {
return nil return nil
} }
@ -148,15 +148,15 @@ func (store *inMemoryStore) BulkDelete(req []state.DeleteRequest) error {
// step3: do really delete // step3: do really delete
for _, dr := range req { for _, dr := range req {
store.doDelete(dr.Key) store.doDelete(ctx, dr.Key)
} }
return nil return nil
} }
func (store *inMemoryStore) Get(req *state.GetRequest) (*state.GetResponse, error) { func (store *inMemoryStore) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) {
item := store.doGetWithReadLock(req.Key) item := store.doGetWithReadLock(ctx, req.Key)
if item != nil && isExpired(item.expire) { if item != nil && isExpired(item.expire) {
item = store.doGetWithWriteLock(req.Key) item = store.doGetWithWriteLock(ctx, req.Key)
} }
if item == nil { if item == nil {
@ -165,14 +165,14 @@ func (store *inMemoryStore) Get(req *state.GetRequest) (*state.GetResponse, erro
return &state.GetResponse{Data: unmarshal(item.data), ETag: item.etag}, nil return &state.GetResponse{Data: unmarshal(item.data), ETag: item.etag}, nil
} }
func (store *inMemoryStore) doGetWithReadLock(key string) *inMemStateStoreItem { func (store *inMemoryStore) doGetWithReadLock(ctx context.Context, key string) *inMemStateStoreItem {
store.lock.RLock() store.lock.RLock()
defer store.lock.RUnlock() defer store.lock.RUnlock()
return store.items[key] return store.items[key]
} }
func (store *inMemoryStore) doGetWithWriteLock(key string) *inMemStateStoreItem { func (store *inMemoryStore) doGetWithWriteLock(ctx context.Context, key string) *inMemStateStoreItem {
store.lock.Lock() store.lock.Lock()
defer store.lock.Unlock() defer store.lock.Unlock()
// get item and check expired again to avoid if item changed between we got this write-lock // get item and check expired again to avoid if item changed between we got this write-lock
@ -181,7 +181,7 @@ func (store *inMemoryStore) doGetWithWriteLock(key string) *inMemStateStoreItem
return nil return nil
} }
if isExpired(item.expire) { if isExpired(item.expire) {
store.doDelete(key) store.doDelete(ctx, key)
return nil return nil
} }
return item return item
@ -194,11 +194,11 @@ func isExpired(expire int64) bool {
return time.Now().UnixMilli() > expire return time.Now().UnixMilli() > expire
} }
func (store *inMemoryStore) BulkGet(req []state.GetRequest) (bool, []state.BulkGetResponse, error) { func (store *inMemoryStore) BulkGet(ctx context.Context, req []state.GetRequest) (bool, []state.BulkGetResponse, error) {
return false, nil, nil return false, nil, nil
} }
func (store *inMemoryStore) Set(req *state.SetRequest) error { func (store *inMemoryStore) Set(ctx context.Context, req *state.SetRequest) error {
// step1: validate parameters // step1: validate parameters
ttlInSeconds, err := store.doSetValidateParameters(req) ttlInSeconds, err := store.doSetValidateParameters(req)
if err != nil { if err != nil {
@ -217,7 +217,7 @@ func (store *inMemoryStore) Set(req *state.SetRequest) error {
// step3: do really set // step3: do really set
// this operation won't fail // this operation won't fail
store.doSet(req.Key, b, req.ETag, ttlInSeconds) store.doSet(ctx, req.Key, b, req.ETag, ttlInSeconds)
return nil return nil
} }
@ -253,7 +253,7 @@ func doParseTTLInSeconds(metadata map[string]string) (int, error) {
return i, nil return i, nil
} }
func (store *inMemoryStore) doSet(key string, data []byte, etag *string, ttlInSeconds int) { func (store *inMemoryStore) doSet(ctx context.Context, key string, data []byte, etag *string, ttlInSeconds int) {
store.items[key] = &inMemStateStoreItem{ store.items[key] = &inMemStateStoreItem{
data: data, data: data,
etag: etag, etag: etag,
@ -268,7 +268,7 @@ type innerSetRequest struct {
data []byte data []byte
} }
func (store *inMemoryStore) BulkSet(req []state.SetRequest) error { func (store *inMemoryStore) BulkSet(ctx context.Context, req []state.SetRequest) error {
if len(req) == 0 { if len(req) == 0 {
return nil return nil
} }
@ -305,12 +305,12 @@ func (store *inMemoryStore) BulkSet(req []state.SetRequest) error {
// step3: do really set // step3: do really set
// these operations won't fail // these operations won't fail
for _, innerSetRequest := range innerSetRequestList { for _, innerSetRequest := range innerSetRequestList {
store.doSet(innerSetRequest.req.Key, innerSetRequest.data, innerSetRequest.req.ETag, innerSetRequest.ttlInSeconds) store.doSet(ctx, innerSetRequest.req.Key, innerSetRequest.data, innerSetRequest.req.ETag, innerSetRequest.ttlInSeconds)
} }
return nil return nil
} }
func (store *inMemoryStore) Multi(request *state.TransactionalStateRequest) error { func (store *inMemoryStore) Multi(ctx context.Context, request *state.TransactionalStateRequest) error {
if len(request.Operations) == 0 { if len(request.Operations) == 0 {
return nil return nil
} }
@ -366,10 +366,10 @@ func (store *inMemoryStore) Multi(request *state.TransactionalStateRequest) erro
for _, o := range request.Operations { for _, o := range request.Operations {
if o.Operation == state.Upsert { if o.Operation == state.Upsert {
s := o.Request.(innerSetRequest) s := o.Request.(innerSetRequest)
store.doSet(s.req.Key, s.data, s.req.ETag, s.ttlInSeconds) store.doSet(ctx, s.req.Key, s.data, s.req.ETag, s.ttlInSeconds)
} else if o.Operation == state.Delete { } else if o.Operation == state.Delete {
d := o.Request.(state.DeleteRequest) d := o.Request.(state.DeleteRequest)
store.doDelete(d.Key) store.doDelete(ctx, d.Key)
} }
} }
return nil return nil
@ -406,7 +406,7 @@ func (store *inMemoryStore) doCleanExpiredItems() {
for key, item := range store.items { for key, item := range store.items {
if isExpired(item.expire) { if isExpired(item.expire) {
store.doDelete(key) store.doDelete(context.Background(), key)
} }
} }
} }

View File

@ -14,6 +14,7 @@ limitations under the License.
package inmemory package inmemory
import ( import (
"context"
"testing" "testing"
"time" "time"
@ -43,13 +44,13 @@ func TestReadAndWrite(t *testing.T) {
Value: valueA, Value: valueA,
ETag: ptr.String("the etag"), ETag: ptr.String("the etag"),
} }
err := store.Set(setReq) err := store.Set(context.TODO(), setReq)
assert.Nil(t, err) assert.Nil(t, err)
// get after set // get after set
getReq := &state.GetRequest{ getReq := &state.GetRequest{
Key: keyA, Key: keyA,
} }
resp, err := store.Get(getReq) resp, err := store.Get(context.TODO(), getReq)
assert.Nil(t, err) assert.Nil(t, err)
assert.NotNil(t, resp) assert.NotNil(t, resp)
assert.Equal(t, valueA, string(resp.Data)) assert.Equal(t, valueA, string(resp.Data))
@ -62,7 +63,7 @@ func TestReadAndWrite(t *testing.T) {
Value: valueA, Value: valueA,
Metadata: map[string]string{"ttlInSeconds": "1"}, Metadata: map[string]string{"ttlInSeconds": "1"},
} }
err := store.Set(setReq) err := store.Set(context.TODO(), setReq)
assert.Nil(t, err) assert.Nil(t, err)
// simulate expiration // simulate expiration
time.Sleep(2 * time.Second) time.Sleep(2 * time.Second)
@ -70,7 +71,7 @@ func TestReadAndWrite(t *testing.T) {
getReq := &state.GetRequest{ getReq := &state.GetRequest{
Key: keyA, Key: keyA,
} }
resp, err := store.Get(getReq) resp, err := store.Get(context.TODO(), getReq)
assert.Nil(t, err) assert.Nil(t, err)
assert.NotNil(t, resp) assert.NotNil(t, resp)
assert.Nil(t, resp.Data) assert.Nil(t, resp.Data)
@ -84,20 +85,20 @@ func TestReadAndWrite(t *testing.T) {
Value: "1234", Value: "1234",
ETag: ptr.String("the etag"), ETag: ptr.String("the etag"),
} }
err := store.Set(setReq) err := store.Set(context.TODO(), setReq)
assert.Nil(t, err) assert.Nil(t, err)
// get // get
getReq := &state.GetRequest{ getReq := &state.GetRequest{
Key: "theSecondKey", Key: "theSecondKey",
} }
resp, err := store.Get(getReq) resp, err := store.Get(context.TODO(), getReq)
assert.Nil(t, err) assert.Nil(t, err)
assert.NotNil(t, resp) assert.NotNil(t, resp)
assert.Equal(t, "1234", string(resp.Data)) assert.Equal(t, "1234", string(resp.Data))
}) })
t.Run("BulkSet two keys", func(t *testing.T) { t.Run("BulkSet two keys", func(t *testing.T) {
err := store.BulkSet([]state.SetRequest{{ err := store.BulkSet(context.TODO(), []state.SetRequest{{
Key: "theFirstKey", Key: "theFirstKey",
Value: "666", Value: "666",
}, { }, {
@ -109,7 +110,7 @@ func TestReadAndWrite(t *testing.T) {
}) })
t.Run("BulkGet fails when not supported", func(t *testing.T) { t.Run("BulkGet fails when not supported", func(t *testing.T) {
supportBulk, _, err := store.BulkGet([]state.GetRequest{{ supportBulk, _, err := store.BulkGet(context.TODO(), []state.GetRequest{{
Key: "theFirstKey", Key: "theFirstKey",
}, { }, {
Key: "theSecondKey", Key: "theSecondKey",
@ -123,7 +124,7 @@ func TestReadAndWrite(t *testing.T) {
req := &state.DeleteRequest{ req := &state.DeleteRequest{
Key: "theFirstKey", Key: "theFirstKey",
} }
err := store.Delete(req) err := store.Delete(context.TODO(), req)
assert.Nil(t, err) assert.Nil(t, err)
}) })
} }

View File

@ -14,6 +14,7 @@ limitations under the License.
package jetstream package jetstream
import ( import (
"context"
"fmt" "fmt"
"strings" "strings"
@ -97,7 +98,7 @@ func (js *StateStore) Features() []state.Feature {
} }
// Get retrieves state with a key. // Get retrieves state with a key.
func (js *StateStore) Get(req *state.GetRequest) (*state.GetResponse, error) { func (js *StateStore) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) {
entry, err := js.bucket.Get(escape(req.Key)) entry, err := js.bucket.Get(escape(req.Key))
if err != nil { if err != nil {
return nil, err return nil, err
@ -109,14 +110,14 @@ func (js *StateStore) Get(req *state.GetRequest) (*state.GetResponse, error) {
} }
// Set stores value for a key. // Set stores value for a key.
func (js *StateStore) Set(req *state.SetRequest) error { func (js *StateStore) Set(ctx context.Context, req *state.SetRequest) error {
bt, _ := utils.Marshal(req.Value, js.json.Marshal) bt, _ := utils.Marshal(req.Value, js.json.Marshal)
_, err := js.bucket.Put(escape(req.Key), bt) _, err := js.bucket.Put(escape(req.Key), bt)
return err return err
} }
// Delete performs a delete operation. // Delete performs a delete operation.
func (js *StateStore) Delete(req *state.DeleteRequest) error { func (js *StateStore) Delete(ctx context.Context, req *state.DeleteRequest) error {
return js.bucket.Delete(escape(req.Key)) return js.bucket.Delete(escape(req.Key))
} }

View File

@ -14,6 +14,7 @@ limitations under the License.
package jetstream package jetstream
import ( import (
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"reflect" "reflect"
@ -112,7 +113,7 @@ func TestSetGetAndDelete(t *testing.T) {
"dkey": "dvalue", "dkey": "dvalue",
} }
err = store.Set(&state.SetRequest{ err = store.Set(context.TODO(), &state.SetRequest{
Key: tkey, Key: tkey,
Value: tData, Value: tData,
}) })
@ -121,7 +122,7 @@ func TestSetGetAndDelete(t *testing.T) {
return return
} }
resp, err := store.Get(&state.GetRequest{ resp, err := store.Get(context.TODO(), &state.GetRequest{
Key: tkey, Key: tkey,
}) })
if err != nil { if err != nil {
@ -134,7 +135,7 @@ func TestSetGetAndDelete(t *testing.T) {
t.Fatal("Response data does not match written data\n") t.Fatal("Response data does not match written data\n")
} }
err = store.Delete(&state.DeleteRequest{ err = store.Delete(context.TODO(), &state.DeleteRequest{
Key: tkey, Key: tkey,
}) })
if err != nil { if err != nil {
@ -142,7 +143,7 @@ func TestSetGetAndDelete(t *testing.T) {
return return
} }
_, err = store.Get(&state.GetRequest{ _, err = store.Get(context.TODO(), &state.GetRequest{
Key: tkey, Key: tkey,
}) })
if err == nil { if err == nil {

View File

@ -14,6 +14,7 @@ limitations under the License.
package memcached package memcached
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"strconv" "strconv"
@ -139,7 +140,7 @@ func (m *Memcached) parseTTL(req *state.SetRequest) (*int32, error) {
return nil, nil return nil, nil
} }
func (m *Memcached) setValue(req *state.SetRequest) error { func (m *Memcached) setValue(ctx context.Context, req *state.SetRequest) error {
var bt []byte var bt []byte
ttl, err := m.parseTTL(req) ttl, err := m.parseTTL(req)
if err != nil { if err != nil {
@ -159,7 +160,7 @@ func (m *Memcached) setValue(req *state.SetRequest) error {
return nil return nil
} }
func (m *Memcached) Delete(req *state.DeleteRequest) error { func (m *Memcached) Delete(ctx context.Context, req *state.DeleteRequest) error {
err := m.client.Delete(req.Key) err := m.client.Delete(req.Key)
if err != nil { if err != nil {
if err == memcache.ErrCacheMiss { if err == memcache.ErrCacheMiss {
@ -171,7 +172,7 @@ func (m *Memcached) Delete(req *state.DeleteRequest) error {
return nil return nil
} }
func (m *Memcached) Get(req *state.GetRequest) (*state.GetResponse, error) { func (m *Memcached) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) {
item, err := m.client.Get(req.Key) item, err := m.client.Get(req.Key)
if err != nil { if err != nil {
// Return nil for status 204 // Return nil for status 204
@ -187,6 +188,6 @@ func (m *Memcached) Get(req *state.GetRequest) (*state.GetResponse, error) {
}, nil }, nil
} }
func (m *Memcached) Set(req *state.SetRequest) error { func (m *Memcached) Set(ctx context.Context, req *state.SetRequest) error {
return state.SetWithOptions(m.setValue, req) return state.SetWithOptions(ctx, m.setValue, req)
} }

View File

@ -159,10 +159,7 @@ func (m *MongoDB) Features() []state.Feature {
} }
// Set saves state into MongoDB. // Set saves state into MongoDB.
func (m *MongoDB) Set(req *state.SetRequest) error { func (m *MongoDB) Set(ctx context.Context, req *state.SetRequest) error {
ctx, cancel := context.WithTimeout(context.Background(), m.operationTimeout)
defer cancel()
err := m.setInternal(ctx, req) err := m.setInternal(ctx, req)
if err != nil { if err != nil {
return err return err
@ -205,12 +202,9 @@ func (m *MongoDB) setInternal(ctx context.Context, req *state.SetRequest) error
} }
// Get retrieves state from MongoDB with a key. // Get retrieves state from MongoDB with a key.
func (m *MongoDB) Get(req *state.GetRequest) (*state.GetResponse, error) { func (m *MongoDB) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) {
var result Item var result Item
ctx, cancel := context.WithTimeout(context.Background(), m.operationTimeout)
defer cancel()
filter := bson.M{id: req.Key} filter := bson.M{id: req.Key}
err := m.collection.FindOne(ctx, filter).Decode(&result) err := m.collection.FindOne(ctx, filter).Decode(&result)
if err != nil { if err != nil {
@ -264,10 +258,7 @@ func (m *MongoDB) Get(req *state.GetRequest) (*state.GetResponse, error) {
} }
// Delete performs a delete operation. // Delete performs a delete operation.
func (m *MongoDB) Delete(req *state.DeleteRequest) error { func (m *MongoDB) Delete(ctx context.Context, req *state.DeleteRequest) error {
ctx, cancel := context.WithTimeout(context.Background(), m.operationTimeout)
defer cancel()
err := m.deleteInternal(ctx, req) err := m.deleteInternal(ctx, req)
if err != nil { if err != nil {
return err return err
@ -294,18 +285,18 @@ func (m *MongoDB) deleteInternal(ctx context.Context, req *state.DeleteRequest)
} }
// Multi performs a transactional operation. succeeds only if all operations succeed, and fails if one or more operations fail. // Multi performs a transactional operation. succeeds only if all operations succeed, and fails if one or more operations fail.
func (m *MongoDB) Multi(request *state.TransactionalStateRequest) error { func (m *MongoDB) Multi(ctx context.Context, request *state.TransactionalStateRequest) error {
sess, err := m.client.StartSession() sess, err := m.client.StartSession()
txnOpts := options.Transaction().SetReadConcern(readconcern.Snapshot()). txnOpts := options.Transaction().SetReadConcern(readconcern.Snapshot()).
SetWriteConcern(writeconcern.New(writeconcern.WMajority())) SetWriteConcern(writeconcern.New(writeconcern.WMajority()))
defer sess.EndSession(context.Background()) defer sess.EndSession(ctx)
if err != nil { if err != nil {
return fmt.Errorf("error in starting the transaction: %s", err) return fmt.Errorf("error in starting the transaction: %s", err)
} }
sess.WithTransaction(context.Background(), func(sessCtx mongo.SessionContext) (interface{}, error) { sess.WithTransaction(ctx, func(sessCtx mongo.SessionContext) (interface{}, error) {
err = m.doTransaction(sessCtx, request.Operations) err = m.doTransaction(sessCtx, request.Operations)
return nil, err return nil, err
@ -336,10 +327,7 @@ func (m *MongoDB) doTransaction(sessCtx mongo.SessionContext, operations []state
} }
// Query executes a query against store. // Query executes a query against store.
func (m *MongoDB) Query(req *state.QueryRequest) (*state.QueryResponse, error) { func (m *MongoDB) Query(ctx context.Context, req *state.QueryRequest) (*state.QueryResponse, error) {
ctx, cancel := context.WithTimeout(context.Background(), m.operationTimeout)
defer cancel()
q := &Query{} q := &Query{}
qbuilder := query.NewQueryBuilder(q) qbuilder := query.NewQueryBuilder(q)
if err := qbuilder.BuildQuery(&req.Query); err != nil { if err := qbuilder.BuildQuery(&req.Query); err != nil {

View File

@ -14,6 +14,7 @@ limitations under the License.
package mysql package mysql
import ( import (
"context"
"database/sql" "database/sql"
"encoding/base64" "encoding/base64"
"encoding/json" "encoding/json"
@ -279,13 +280,13 @@ func tableExists(db *sql.DB, tableName string) (bool, error) {
// Delete removes an entity from the store // Delete removes an entity from the store
// Store Interface. // Store Interface.
func (m *MySQL) Delete(req *state.DeleteRequest) error { func (m *MySQL) Delete(ctx context.Context, req *state.DeleteRequest) error {
return state.DeleteWithOptions(m.deleteValue, req) return state.DeleteWithOptions(ctx, m.deleteValue, req)
} }
// deleteValue is an internal implementation of delete to enable passing the // deleteValue is an internal implementation of delete to enable passing the
// logic to state.DeleteWithRetries as a func. // logic to state.DeleteWithRetries as a func.
func (m *MySQL) deleteValue(req *state.DeleteRequest) error { func (m *MySQL) deleteValue(ctx context.Context, req *state.DeleteRequest) error {
m.logger.Debug("Deleting state value from MySql") m.logger.Debug("Deleting state value from MySql")
if req.Key == "" { if req.Key == "" {
@ -296,11 +297,11 @@ func (m *MySQL) deleteValue(req *state.DeleteRequest) error {
var result sql.Result var result sql.Result
if req.ETag == nil || *req.ETag == "" { if req.ETag == nil || *req.ETag == "" {
result, err = m.db.Exec(fmt.Sprintf( result, err = m.db.ExecContext(ctx, fmt.Sprintf(
`DELETE FROM %s WHERE id = ?`, `DELETE FROM %s WHERE id = ?`,
m.tableName), req.Key) m.tableName), req.Key)
} else { } else {
result, err = m.db.Exec(fmt.Sprintf( result, err = m.db.ExecContext(ctx, fmt.Sprintf(
`DELETE FROM %s WHERE id = ? and eTag = ?`, `DELETE FROM %s WHERE id = ? and eTag = ?`,
m.tableName), req.Key, *req.ETag) m.tableName), req.Key, *req.ETag)
} }
@ -323,7 +324,7 @@ func (m *MySQL) deleteValue(req *state.DeleteRequest) error {
// BulkDelete removes multiple entries from the store // BulkDelete removes multiple entries from the store
// Store Interface. // Store Interface.
func (m *MySQL) BulkDelete(req []state.DeleteRequest) error { func (m *MySQL) BulkDelete(ctx context.Context, req []state.DeleteRequest) error {
m.logger.Debug("Executing BulkDelete request") m.logger.Debug("Executing BulkDelete request")
tx, err := m.db.Begin() tx, err := m.db.Begin()
@ -334,7 +335,7 @@ func (m *MySQL) BulkDelete(req []state.DeleteRequest) error {
if len(req) > 0 { if len(req) > 0 {
for _, d := range req { for _, d := range req {
da := d // Fix for goSec G601: Implicit memory aliasing in for loop. da := d // Fix for goSec G601: Implicit memory aliasing in for loop.
err = m.Delete(&da) err = m.Delete(ctx, &da)
if err != nil { if err != nil {
tx.Rollback() tx.Rollback()
@ -350,7 +351,7 @@ func (m *MySQL) BulkDelete(req []state.DeleteRequest) error {
// Get returns an entity from store // Get returns an entity from store
// Store Interface. // Store Interface.
func (m *MySQL) Get(req *state.GetRequest) (*state.GetResponse, error) { func (m *MySQL) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) {
m.logger.Debug("Getting state value from MySql") m.logger.Debug("Getting state value from MySql")
if req.Key == "" { if req.Key == "" {
@ -368,7 +369,7 @@ func (m *MySQL) Get(req *state.GetRequest) (*state.GetResponse, error) {
`SELECT value, eTag, isbinary FROM %s WHERE id = ?`, `SELECT value, eTag, isbinary FROM %s WHERE id = ?`,
m.tableName, // m.tableName is sanitized m.tableName, // m.tableName is sanitized
) )
err := m.db.QueryRow(query, req.Key).Scan(&value, &eTag, &isBinary) err := m.db.QueryRowContext(ctx, query, req.Key).Scan(&value, &eTag, &isBinary)
if err != nil { if err != nil {
// If no rows exist, return an empty response, otherwise return an error. // If no rows exist, return an empty response, otherwise return an error.
if errors.Is(err, sql.ErrNoRows) { if errors.Is(err, sql.ErrNoRows) {
@ -410,13 +411,13 @@ func (m *MySQL) Get(req *state.GetRequest) (*state.GetResponse, error) {
// Set adds/updates an entity on store // Set adds/updates an entity on store
// Store Interface. // Store Interface.
func (m *MySQL) Set(req *state.SetRequest) error { func (m *MySQL) Set(ctx context.Context, req *state.SetRequest) error {
return state.SetWithOptions(m.setValue, req) return state.SetWithOptions(ctx, m.setValue, req)
} }
// setValue is an internal implementation of set to enable passing the logic // setValue is an internal implementation of set to enable passing the logic
// to state.SetWithRetries as a func. // to state.SetWithRetries as a func.
func (m *MySQL) setValue(req *state.SetRequest) error { func (m *MySQL) setValue(ctx context.Context, req *state.SetRequest) error {
m.logger.Debug("Setting state value in MySql") m.logger.Debug("Setting state value in MySql")
err := state.CheckRequestOptions(req.Options) err := state.CheckRequestOptions(req.Options)
@ -457,7 +458,7 @@ func (m *MySQL) setValue(req *state.SetRequest) error {
`INSERT INTO %s (value, id, eTag, isbinary) VALUES (?, ?, ?, ?);`, `INSERT INTO %s (value, id, eTag, isbinary) VALUES (?, ?, ?, ?);`,
m.tableName, // m.tableName is sanitized m.tableName, // m.tableName is sanitized
) )
result, err = m.db.Exec(query, enc, req.Key, eTag, isBinary) result, err = m.db.ExecContext(ctx, query, enc, req.Key, eTag, isBinary)
} else if req.ETag != nil && *req.ETag != "" { } else if req.ETag != nil && *req.ETag != "" {
// When an eTag is provided do an update - not insert // When an eTag is provided do an update - not insert
//nolint:gosec //nolint:gosec
@ -465,7 +466,7 @@ func (m *MySQL) setValue(req *state.SetRequest) error {
`UPDATE %s SET value = ?, eTag = ?, isbinary = ? WHERE id = ? AND eTag = ?;`, `UPDATE %s SET value = ?, eTag = ?, isbinary = ? WHERE id = ? AND eTag = ?;`,
m.tableName, // m.tableName is sanitized m.tableName, // m.tableName is sanitized
) )
result, err = m.db.Exec(query, enc, eTag, isBinary, req.Key, *req.ETag) result, err = m.db.ExecContext(ctx, query, enc, eTag, isBinary, req.Key, *req.ETag)
} else { } else {
// If this is a duplicate MySQL returns that two rows affected // If this is a duplicate MySQL returns that two rows affected
maxRows = 2 maxRows = 2
@ -474,7 +475,7 @@ func (m *MySQL) setValue(req *state.SetRequest) error {
`INSERT INTO %s (value, id, eTag, isbinary) VALUES (?, ?, ?, ?) on duplicate key update value=?, eTag=?, isbinary=?;`, `INSERT INTO %s (value, id, eTag, isbinary) VALUES (?, ?, ?, ?) on duplicate key update value=?, eTag=?, isbinary=?;`,
m.tableName, // m.tableName is sanitized m.tableName, // m.tableName is sanitized
) )
result, err = m.db.Exec(query, enc, req.Key, eTag, isBinary, enc, eTag, isBinary) result, err = m.db.ExecContext(ctx, query, enc, req.Key, eTag, isBinary, enc, eTag, isBinary)
} }
if err != nil { if err != nil {
@ -508,7 +509,7 @@ func (m *MySQL) setValue(req *state.SetRequest) error {
// BulkSet adds/updates multiple entities on store // BulkSet adds/updates multiple entities on store
// Store Interface. // Store Interface.
func (m *MySQL) BulkSet(req []state.SetRequest) error { func (m *MySQL) BulkSet(ctx context.Context, req []state.SetRequest) error {
m.logger.Debug("Executing BulkSet request") m.logger.Debug("Executing BulkSet request")
tx, err := m.db.Begin() tx, err := m.db.Begin()
@ -518,7 +519,7 @@ func (m *MySQL) BulkSet(req []state.SetRequest) error {
if len(req) > 0 { if len(req) > 0 {
for i := range req { for i := range req {
err = m.Set(&req[i]) err = m.Set(ctx, &req[i])
if err != nil { if err != nil {
tx.Rollback() tx.Rollback()
return err return err
@ -531,7 +532,7 @@ func (m *MySQL) BulkSet(req []state.SetRequest) error {
// Multi handles multiple transactions. // Multi handles multiple transactions.
// TransactionalStore Interface. // TransactionalStore Interface.
func (m *MySQL) Multi(request *state.TransactionalStateRequest) error { func (m *MySQL) Multi(ctx context.Context, request *state.TransactionalStateRequest) error {
m.logger.Debug("Executing Multi request") m.logger.Debug("Executing Multi request")
tx, err := m.db.Begin() tx, err := m.db.Begin()
@ -548,7 +549,7 @@ func (m *MySQL) Multi(request *state.TransactionalStateRequest) error {
return err return err
} }
err = m.Set(&setReq) err = m.Set(ctx, &setReq)
if err != nil { if err != nil {
_ = tx.Rollback() _ = tx.Rollback()
return err return err
@ -561,7 +562,7 @@ func (m *MySQL) Multi(request *state.TransactionalStateRequest) error {
return err return err
} }
err = m.Delete(&delReq) err = m.Delete(ctx, &delReq)
if err != nil { if err != nil {
_ = tx.Rollback() _ = tx.Rollback()
return err return err
@ -604,7 +605,7 @@ func (m *MySQL) getDeletes(req state.TransactionalStateOperation) (state.DeleteR
} }
// BulkGet performs a bulks get operations. // BulkGet performs a bulks get operations.
func (m *MySQL) BulkGet(req []state.GetRequest) (bool, []state.BulkGetResponse, error) { func (m *MySQL) BulkGet(ctx context.Context, req []state.GetRequest) (bool, []state.BulkGetResponse, error) {
// by default, the store doesn't support bulk get // by default, the store doesn't support bulk get
// return false so daprd will fallback to call get() method one by one // return false so daprd will fallback to call get() method one by one
return false, nil, nil return false, nil, nil

View File

@ -15,6 +15,7 @@ limitations under the License.
package mysql package mysql
import ( import (
"context"
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"database/sql" "database/sql"
@ -205,7 +206,7 @@ func TestMySQLIntegration(t *testing.T) {
Key: "", Key: "",
} }
response, getErr := mys.Get(getReq) response, getErr := mys.Get(context.TODO(), getReq)
assert.NotNil(t, getErr) assert.NotNil(t, getErr)
assert.Nil(t, response) assert.Nil(t, response)
}) })
@ -242,7 +243,7 @@ func TestMySQLIntegration(t *testing.T) {
Key: "", Key: "",
} }
err := mys.Set(setReq) err := mys.Set(context.TODO(), setReq)
assert.NotNil(t, err, "Error was not nil when setting item with no key.") assert.NotNil(t, err, "Error was not nil when setting item with no key.")
}) })
@ -302,7 +303,7 @@ func TestMySQLIntegration(t *testing.T) {
Value: newValue, Value: newValue,
} }
err := mys.Set(setReq) err := mys.Set(context.TODO(), setReq)
assert.NotNil(t, err, "Error was not thrown using old eTag") assert.NotNil(t, err, "Error was not thrown using old eTag")
}) })
@ -318,7 +319,7 @@ func TestMySQLIntegration(t *testing.T) {
Value: value, Value: value,
} }
err := mys.Set(setReq) err := mys.Set(context.TODO(), setReq)
assert.NotNil(t, err) assert.NotNil(t, err)
}) })
@ -338,7 +339,7 @@ func TestMySQLIntegration(t *testing.T) {
ETag: &eTag, ETag: &eTag,
} }
err := mys.Delete(deleteReq) err := mys.Delete(context.TODO(), deleteReq)
assert.NotNil(t, err) assert.NotNil(t, err)
}) })
@ -349,7 +350,7 @@ func TestMySQLIntegration(t *testing.T) {
Key: "", Key: "",
} }
err := mys.Delete(deleteReq) err := mys.Delete(context.TODO(), deleteReq)
assert.NotNil(t, err) assert.NotNil(t, err)
}) })
@ -361,7 +362,7 @@ func TestMySQLIntegration(t *testing.T) {
Key: randomKey(), Key: randomKey(),
} }
err := mys.Delete(deleteReq) err := mys.Delete(context.TODO(), deleteReq)
assert.Nil(t, err) assert.Nil(t, err)
}) })
@ -378,7 +379,7 @@ func TestMySQLIntegration(t *testing.T) {
}, },
} }
err := mys.Set(setReq) err := mys.Set(context.TODO(), setReq)
assert.NoError(t, err) assert.NoError(t, err)
// Get the etag // Get the etag
@ -396,7 +397,7 @@ func TestMySQLIntegration(t *testing.T) {
}, },
} }
err = mys.Set(setReq) err = mys.Set(context.TODO(), setReq)
assert.ErrorContains(t, err, "Duplicate entry") assert.ErrorContains(t, err, "Duplicate entry")
// Insert with invalid etag should fail on existing keys // Insert with invalid etag should fail on existing keys
@ -409,7 +410,7 @@ func TestMySQLIntegration(t *testing.T) {
}, },
} }
err = mys.Set(setReq) err = mys.Set(context.TODO(), setReq)
assert.ErrorContains(t, err, "possible etag mismatch") assert.ErrorContains(t, err, "possible etag mismatch")
// Insert with valid etag should succeed on existing keys // Insert with valid etag should succeed on existing keys
@ -422,7 +423,7 @@ func TestMySQLIntegration(t *testing.T) {
}, },
} }
err = mys.Set(setReq) err = mys.Set(context.TODO(), setReq)
assert.NoError(t, err) assert.NoError(t, err)
// Insert with an etag should fail on new keys // Insert with an etag should fail on new keys
@ -435,7 +436,7 @@ func TestMySQLIntegration(t *testing.T) {
}, },
} }
err = mys.Set(setReq) err = mys.Set(context.TODO(), setReq)
assert.ErrorContains(t, err, "possible etag mismatch") assert.ErrorContains(t, err, "possible etag mismatch")
}) })
@ -474,7 +475,7 @@ func TestMySQLIntegration(t *testing.T) {
}) })
} }
err := mys.Multi(&state.TransactionalStateRequest{ err := mys.Multi(context.TODO(), &state.TransactionalStateRequest{
Operations: operations, Operations: operations,
}) })
assert.Nil(t, err) assert.Nil(t, err)
@ -510,7 +511,7 @@ func TestMySQLIntegration(t *testing.T) {
}) })
} }
err := mys.Multi(&state.TransactionalStateRequest{ err := mys.Multi(context.TODO(), &state.TransactionalStateRequest{
Operations: operations, Operations: operations,
}) })
assert.Nil(t, err) assert.Nil(t, err)
@ -537,7 +538,7 @@ func TestMySQLIntegration(t *testing.T) {
}) })
} }
err := mys.Multi(&state.TransactionalStateRequest{ err := mys.Multi(context.TODO(), &state.TransactionalStateRequest{
Operations: operations, Operations: operations,
}) })
assert.Nil(t, err) assert.Nil(t, err)
@ -562,7 +563,7 @@ func testBulkSetAndBulkDelete(t *testing.T, mys *MySQL) {
}, },
} }
err := mys.BulkSet(setReq) err := mys.BulkSet(context.TODO(), setReq)
assert.Nil(t, err) assert.Nil(t, err)
assert.True(t, storeItemExists(t, setReq[0].Key)) assert.True(t, storeItemExists(t, setReq[0].Key))
assert.True(t, storeItemExists(t, setReq[1].Key)) assert.True(t, storeItemExists(t, setReq[1].Key))
@ -576,7 +577,7 @@ func testBulkSetAndBulkDelete(t *testing.T, mys *MySQL) {
}, },
} }
err = mys.BulkDelete(deleteReq) err = mys.BulkDelete(context.TODO(), deleteReq)
assert.Nil(t, err) assert.Nil(t, err)
assert.False(t, storeItemExists(t, setReq[0].Key)) assert.False(t, storeItemExists(t, setReq[0].Key))
assert.False(t, storeItemExists(t, setReq[1].Key)) assert.False(t, storeItemExists(t, setReq[1].Key))
@ -596,7 +597,7 @@ func setItem(t *testing.T, mys *MySQL, key string, value interface{}, eTag *stri
Value: value, Value: value,
} }
err := mys.Set(setReq) err := mys.Set(context.TODO(), setReq)
assert.Nil(t, err, "Error setting an item") assert.Nil(t, err, "Error setting an item")
itemExists := storeItemExists(t, key) itemExists := storeItemExists(t, key)
assert.True(t, itemExists, "Item does not exist after being set") assert.True(t, itemExists, "Item does not exist after being set")
@ -608,7 +609,7 @@ func getItem(t *testing.T, mys *MySQL, key string) (*state.GetResponse, *fakeIte
Options: state.GetStateOption{}, Options: state.GetStateOption{},
} }
response, getErr := mys.Get(getReq) response, getErr := mys.Get(context.TODO(), getReq)
assert.Nil(t, getErr) assert.Nil(t, getErr)
assert.NotNil(t, response) assert.NotNil(t, response)
outputObject := &fakeItem{} outputObject := &fakeItem{}
@ -624,7 +625,7 @@ func deleteItem(t *testing.T, mys *MySQL, key string, eTag *string) {
Options: state.DeleteStateOption{}, Options: state.DeleteStateOption{},
} }
deleteErr := mys.Delete(deleteReq) deleteErr := mys.Delete(context.TODO(), deleteReq)
assert.Nil(t, deleteErr, "There was an error deleting a record") assert.Nil(t, deleteErr, "There was an error deleting a record")
assert.False(t, storeItemExists(t, key), "Item still exists after delete") assert.False(t, storeItemExists(t, key), "Item still exists after delete")
} }

View File

@ -15,6 +15,7 @@ limitations under the License.
package mysql package mysql
import ( import (
"context"
"database/sql" "database/sql"
"encoding/base64" "encoding/base64"
"encoding/json" "encoding/json"
@ -161,7 +162,7 @@ func TestExecuteMultiCannotBeginTransaction(t *testing.T) {
m.mock1.ExpectBegin().WillReturnError(fmt.Errorf("beginError")) m.mock1.ExpectBegin().WillReturnError(fmt.Errorf("beginError"))
// Act // Act
err := m.mySQL.Multi(nil) err := m.mySQL.Multi(context.TODO(), nil)
// Assert // Assert
assert.NotNil(t, err, "no error returned") assert.NotNil(t, err, "no error returned")
@ -180,7 +181,7 @@ func TestMySQLBulkDeleteRollbackDeletes(t *testing.T) {
deletes := []state.DeleteRequest{createDeleteRequest()} deletes := []state.DeleteRequest{createDeleteRequest()}
// Act // Act
err := m.mySQL.BulkDelete(deletes) err := m.mySQL.BulkDelete(context.TODO(), deletes)
// Assert // Assert
assert.NotNil(t, err, "no error returned") assert.NotNil(t, err, "no error returned")
@ -199,7 +200,7 @@ func TestMySQLBulkSetRollbackSets(t *testing.T) {
sets := []state.SetRequest{createSetRequest()} sets := []state.SetRequest{createSetRequest()}
// Act // Act
err := m.mySQL.BulkSet(sets) err := m.mySQL.BulkSet(context.TODO(), sets)
// Assert // Assert
assert.NotNil(t, err, "no error returned") assert.NotNil(t, err, "no error returned")
@ -232,7 +233,7 @@ func TestExecuteMultiCommitSetsAndDeletes(t *testing.T) {
} }
// Act // Act
err := m.mySQL.Multi(&request) err := m.mySQL.Multi(context.TODO(), &request)
// Assert // Assert
assert.Nil(t, err, "error returned") assert.Nil(t, err, "error returned")
@ -248,7 +249,7 @@ func TestSetHandlesOptionsError(t *testing.T) {
request.Options.Consistency = "Invalid" request.Options.Consistency = "Invalid"
// Act // Act
err := m.mySQL.setValue(&request) err := m.mySQL.setValue(context.TODO(), &request)
// Assert // Assert
assert.NotNil(t, err) assert.NotNil(t, err)
@ -263,7 +264,7 @@ func TestSetHandlesNoKey(t *testing.T) {
request.Key = "" request.Key = ""
// Act // Act
err := m.mySQL.Set(&request) err := m.mySQL.Set(context.TODO(), &request)
// Assert // Assert
assert.NotNil(t, err) assert.NotNil(t, err)
@ -283,7 +284,7 @@ func TestSetHandlesUpdate(t *testing.T) {
request.ETag = &eTag request.ETag = &eTag
// Act // Act
err := m.mySQL.setValue(&request) err := m.mySQL.setValue(context.TODO(), &request)
// Assert // Assert
assert.Nil(t, err) assert.Nil(t, err)
@ -302,7 +303,7 @@ func TestSetHandlesErr(t *testing.T) {
request.ETag = &eTag request.ETag = &eTag
// Act // Act
err := m.mySQL.setValue(&request) err := m.mySQL.setValue(context.TODO(), &request)
// Assert // Assert
assert.NotNil(t, err) assert.NotNil(t, err)
@ -315,7 +316,7 @@ func TestSetHandlesErr(t *testing.T) {
request := createSetRequest() request := createSetRequest()
// Act // Act
err := m.mySQL.setValue(&request) err := m.mySQL.setValue(context.TODO(), &request)
// Assert // Assert
assert.NotNil(t, err) assert.NotNil(t, err)
@ -327,7 +328,7 @@ func TestSetHandlesErr(t *testing.T) {
request := createSetRequest() request := createSetRequest()
// Act // Act
err := m.mySQL.setValue(&request) err := m.mySQL.setValue(context.TODO(), &request)
// Assert // Assert
assert.Nil(t, err) assert.Nil(t, err)
@ -338,7 +339,7 @@ func TestSetHandlesErr(t *testing.T) {
request := createSetRequest() request := createSetRequest()
// Act // Act
err := m.mySQL.setValue(&request) err := m.mySQL.setValue(context.TODO(), &request)
// Assert // Assert
assert.NotNil(t, err) assert.NotNil(t, err)
@ -352,7 +353,7 @@ func TestSetHandlesErr(t *testing.T) {
request.ETag = &eTag request.ETag = &eTag
// Act // Act
err := m.mySQL.setValue(&request) err := m.mySQL.setValue(context.TODO(), &request)
// Assert // Assert
assert.NotNil(t, err) assert.NotNil(t, err)
@ -369,7 +370,7 @@ func TestMySQLDeleteHandlesNoKey(t *testing.T) {
request.Key = "" request.Key = ""
// Act // Act
err := m.mySQL.Delete(&request) err := m.mySQL.Delete(context.TODO(), &request)
// Asset // Asset
assert.NotNil(t, err) assert.NotNil(t, err)
@ -388,7 +389,7 @@ func TestDeleteWithETag(t *testing.T) {
request.ETag = &eTag request.ETag = &eTag
// Act // Act
err := m.mySQL.deleteValue(&request) err := m.mySQL.deleteValue(context.TODO(), &request)
// Assert // Assert
assert.Nil(t, err) assert.Nil(t, err)
@ -405,7 +406,7 @@ func TestDeleteWithErr(t *testing.T) {
request := createDeleteRequest() request := createDeleteRequest()
// Act // Act
err := m.mySQL.deleteValue(&request) err := m.mySQL.deleteValue(context.TODO(), &request)
// Assert // Assert
assert.NotNil(t, err) assert.NotNil(t, err)
@ -420,7 +421,7 @@ func TestDeleteWithErr(t *testing.T) {
request.ETag = &eTag request.ETag = &eTag
// Act // Act
err := m.mySQL.deleteValue(&request) err := m.mySQL.deleteValue(context.TODO(), &request)
// Assert // Assert
assert.NotNil(t, err) assert.NotNil(t, err)
@ -441,7 +442,7 @@ func TestGetHandlesNoRows(t *testing.T) {
} }
// Act // Act
response, err := m.mySQL.Get(request) response, err := m.mySQL.Get(context.TODO(), request)
// Assert // Assert
assert.Nil(t, err, "returned error") assert.Nil(t, err, "returned error")
@ -458,7 +459,7 @@ func TestGetHandlesNoKey(t *testing.T) {
} }
// Act // Act
response, err := m.mySQL.Get(request) response, err := m.mySQL.Get(context.TODO(), request)
// Assert // Assert
assert.NotNil(t, err, "returned error") assert.NotNil(t, err, "returned error")
@ -478,7 +479,7 @@ func TestGetHandlesGenericError(t *testing.T) {
} }
// Act // Act
response, err := m.mySQL.Get(request) response, err := m.mySQL.Get(context.TODO(), request)
// Assert // Assert
assert.NotNil(t, err) assert.NotNil(t, err)
@ -499,7 +500,7 @@ func TestGetSucceeds(t *testing.T) {
} }
// Act // Act
response, err := m.mySQL.Get(request) response, err := m.mySQL.Get(context.TODO(), request)
// Assert // Assert
assert.Nil(t, err) assert.Nil(t, err)
@ -517,7 +518,7 @@ func TestGetSucceeds(t *testing.T) {
} }
// Act // Act
response, err := m.mySQL.Get(request) response, err := m.mySQL.Get(context.TODO(), request)
// Assert // Assert
assert.Nil(t, err) assert.Nil(t, err)
@ -711,7 +712,7 @@ func TestBulkGetReturnsNil(t *testing.T) {
m, _ := mockDatabase(t) m, _ := mockDatabase(t)
// Act // Act
supported, response, err := m.mySQL.BulkGet(nil) supported, response, err := m.mySQL.BulkGet(context.TODO(), nil)
// Assert // Assert
assert.Nil(t, err, `returned err`) assert.Nil(t, err, `returned err`)
@ -730,7 +731,7 @@ func TestMultiWithNoRequestsDoesNothing(t *testing.T) {
m.mock1.ExpectCommit() m.mock1.ExpectCommit()
// Act // Act
err := m.mySQL.Multi(&state.TransactionalStateRequest{ err := m.mySQL.Multi(context.TODO(), &state.TransactionalStateRequest{
Operations: ops, Operations: ops,
}) })
@ -750,7 +751,7 @@ func TestInvalidMultiAction(t *testing.T) {
}) })
// Act // Act
err := m.mySQL.Multi(&state.TransactionalStateRequest{ err := m.mySQL.Multi(context.TODO(), &state.TransactionalStateRequest{
Operations: ops, Operations: ops,
}) })
@ -789,7 +790,7 @@ func TestValidSetRequest(t *testing.T) {
m.mock1.ExpectCommit() m.mock1.ExpectCommit()
// Act // Act
err := m.mySQL.Multi(&state.TransactionalStateRequest{ err := m.mySQL.Multi(context.TODO(), &state.TransactionalStateRequest{
Operations: ops, Operations: ops,
}) })
@ -810,7 +811,7 @@ func TestInvalidMultiSetRequest(t *testing.T) {
}) })
// Act // Act
err := m.mySQL.Multi(&state.TransactionalStateRequest{ err := m.mySQL.Multi(context.TODO(), &state.TransactionalStateRequest{
Operations: ops, Operations: ops,
}) })
@ -834,7 +835,7 @@ func TestInvalidMultiSetRequestNoKey(t *testing.T) {
}) })
// Act // Act
err := m.mySQL.Multi(&state.TransactionalStateRequest{ err := m.mySQL.Multi(context.TODO(), &state.TransactionalStateRequest{
Operations: ops, Operations: ops,
}) })
@ -858,7 +859,7 @@ func TestValidMultiDeleteRequest(t *testing.T) {
}) })
// Act // Act
err := m.mySQL.Multi(&state.TransactionalStateRequest{ err := m.mySQL.Multi(context.TODO(), &state.TransactionalStateRequest{
Operations: ops, Operations: ops,
}) })
@ -879,7 +880,7 @@ func TestInvalidMultiDeleteRequest(t *testing.T) {
}) })
// Act // Act
err := m.mySQL.Multi(&state.TransactionalStateRequest{ err := m.mySQL.Multi(context.TODO(), &state.TransactionalStateRequest{
Operations: ops, Operations: ops,
}) })
@ -902,7 +903,7 @@ func TestInvalidMultiDeleteRequestNoKey(t *testing.T) {
}) })
// Act // Act
err := m.mySQL.Multi(&state.TransactionalStateRequest{ err := m.mySQL.Multi(context.TODO(), &state.TransactionalStateRequest{
Operations: ops, Operations: ops,
}) })
@ -941,7 +942,7 @@ func TestMultiOperationOrder(t *testing.T) {
m.mock1.ExpectCommit() m.mock1.ExpectCommit()
// Act // Act
err := m.mySQL.Multi(&state.TransactionalStateRequest{ err := m.mySQL.Multi(context.TODO(), &state.TransactionalStateRequest{
Operations: ops, Operations: ops,
}) })

View File

@ -131,15 +131,15 @@ func (r *StateStore) Features() []state.Feature {
return r.features return r.features
} }
func (r *StateStore) Delete(req *state.DeleteRequest) error { func (r *StateStore) Delete(ctx context.Context, req *state.DeleteRequest) error {
r.logger.Debugf("Delete entry from OCI Object Storage State Store with key ", req.Key) r.logger.Debugf("Delete entry from OCI Object Storage State Store with key ", req.Key)
err := r.deleteDocument(req) err := r.deleteDocument(ctx, req)
return err return err
} }
func (r *StateStore) Get(req *state.GetRequest) (*state.GetResponse, error) { func (r *StateStore) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) {
r.logger.Debugf("Get from OCI Object Storage State Store with key ", req.Key) r.logger.Debugf("Get from OCI Object Storage State Store with key ", req.Key)
content, etag, err := r.readDocument((req)) content, etag, err := r.readDocument(ctx, req)
if err != nil { if err != nil {
r.logger.Debugf("error %s", err) r.logger.Debugf("error %s", err)
if err.Error() == "ObjectNotFound" { if err.Error() == "ObjectNotFound" {
@ -155,9 +155,9 @@ func (r *StateStore) Get(req *state.GetRequest) (*state.GetResponse, error) {
}, err }, err
} }
func (r *StateStore) Set(req *state.SetRequest) error { func (r *StateStore) Set(ctx context.Context, req *state.SetRequest) error {
r.logger.Debugf("saving %s to OCI Object Storage State Store", req.Key) r.logger.Debugf("saving %s to OCI Object Storage State Store", req.Key)
return r.writeDocument(req) return r.writeDocument(ctx, req)
} }
func (r *StateStore) Ping() error { func (r *StateStore) Ping() error {
@ -267,7 +267,7 @@ func getIdentityAuthenticationDetails(metadata map[string]string, meta *Metadata
} }
// functions that bridge from the Dapr State API to the OCI ObjectStorage Client. // functions that bridge from the Dapr State API to the OCI ObjectStorage Client.
func (r *StateStore) writeDocument(req *state.SetRequest) error { func (r *StateStore) writeDocument(ctx context.Context, req *state.SetRequest) error {
if len(req.Key) == 0 || req.Key == "" { if len(req.Key) == 0 || req.Key == "" {
return fmt.Errorf("key for value to set was missing from request") return fmt.Errorf("key for value to set was missing from request")
} }
@ -286,7 +286,6 @@ func (r *StateStore) writeDocument(req *state.SetRequest) error {
objectName := getFileName(req.Key) objectName := getFileName(req.Key)
content := r.marshal(req) content := r.marshal(req)
objectLength := int64(len(content)) objectLength := int64(len(content))
ctx := context.Background()
etag := req.ETag etag := req.ETag
if req.Options.Concurrency != state.FirstWrite { if req.Options.Concurrency != state.FirstWrite {
etag = nil etag = nil
@ -315,12 +314,11 @@ func (r *StateStore) convertTTLtoExpiryTime(req *state.SetRequest, metadata map[
return nil return nil
} }
func (r *StateStore) readDocument(req *state.GetRequest) ([]byte, *string, error) { func (r *StateStore) readDocument(ctx context.Context, req *state.GetRequest) ([]byte, *string, error) {
if len(req.Key) == 0 || req.Key == "" { if len(req.Key) == 0 || req.Key == "" {
return nil, nil, fmt.Errorf("key for value to get was missing from request") return nil, nil, fmt.Errorf("key for value to get was missing from request")
} }
objectName := getFileName(req.Key) objectName := getFileName(req.Key)
ctx := context.Background()
content, etag, meta, err := r.client.getObject(ctx, objectName) content, etag, meta, err := r.client.getObject(ctx, objectName)
if err != nil { if err != nil {
r.logger.Debugf("download file %s, err %s", req.Key, err) r.logger.Debugf("download file %s, err %s", req.Key, err)
@ -348,13 +346,12 @@ func (r *StateStore) pingBucket() error {
return nil return nil
} }
func (r *StateStore) deleteDocument(req *state.DeleteRequest) error { func (r *StateStore) deleteDocument(ctx context.Context, req *state.DeleteRequest) error {
if len(req.Key) == 0 || req.Key == "" { if len(req.Key) == 0 || req.Key == "" {
return fmt.Errorf("key for value to delete was missing from request") return fmt.Errorf("key for value to delete was missing from request")
} }
objectName := getFileName(req.Key) objectName := getFileName(req.Key)
ctx := context.Background()
etag := req.ETag etag := req.ETag
if req.Options.Concurrency != state.FirstWrite { if req.Options.Concurrency != state.FirstWrite {
etag = nil etag = nil

View File

@ -4,6 +4,7 @@ package objectstorage
// go test -v github.com/dapr/components-contrib/state/oci/objectstorage. // go test -v github.com/dapr/components-contrib/state/oci/objectstorage.
import ( import (
"context"
"fmt" "fmt"
"os" "os"
"testing" "testing"
@ -88,16 +89,16 @@ func testGet(t *testing.T, ociProperties map[string]string) {
t.Run("Get an non-existing key", func(t *testing.T) { t.Run("Get an non-existing key", func(t *testing.T) {
err := statestore.Init(meta) err := statestore.Init(meta)
assert.Nil(t, err) assert.Nil(t, err)
getResponse, err := statestore.Get(&state.GetRequest{Key: "xyzq"}) getResponse, err := statestore.Get(context.TODO(), &state.GetRequest{Key: "xyzq"})
assert.Equal(t, &state.GetResponse{}, getResponse, "Response must be empty") assert.Equal(t, &state.GetResponse{}, getResponse, "Response must be empty")
assert.NoError(t, err, "Non-existing key must not be treated as error") assert.NoError(t, err, "Non-existing key must not be treated as error")
}) })
t.Run("Get an existing key", func(t *testing.T) { t.Run("Get an existing key", func(t *testing.T) {
err := statestore.Init(meta) err := statestore.Init(meta)
assert.Nil(t, err) assert.Nil(t, err)
err = statestore.Set(&state.SetRequest{Key: "test-key", Value: []byte("test-value")}) err = statestore.Set(context.TODO(), &state.SetRequest{Key: "test-key", Value: []byte("test-value")})
assert.Nil(t, err) assert.Nil(t, err)
getResponse, err := statestore.Get(&state.GetRequest{Key: "test-key"}) getResponse, err := statestore.Get(context.TODO(), &state.GetRequest{Key: "test-key"})
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, "test-value", string(getResponse.Data), "Value retrieved should be equal to value set") assert.Equal(t, "test-value", string(getResponse.Data), "Value retrieved should be equal to value set")
assert.NotNil(t, *getResponse.ETag, "ETag should be set") assert.NotNil(t, *getResponse.ETag, "ETag should be set")
@ -105,9 +106,9 @@ func testGet(t *testing.T, ociProperties map[string]string) {
t.Run("Get an existing composed key", func(t *testing.T) { t.Run("Get an existing composed key", func(t *testing.T) {
err := statestore.Init(meta) err := statestore.Init(meta)
assert.Nil(t, err) assert.Nil(t, err)
err = statestore.Set(&state.SetRequest{Key: "test-app||test-key", Value: []byte("test-value")}) err = statestore.Set(context.TODO(), &state.SetRequest{Key: "test-app||test-key", Value: []byte("test-value")})
assert.Nil(t, err) assert.Nil(t, err)
getResponse, err := statestore.Get(&state.GetRequest{Key: "test-app||test-key"}) getResponse, err := statestore.Get(context.TODO(), &state.GetRequest{Key: "test-app||test-key"})
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, "test-value", string(getResponse.Data), "Value retrieved should be equal to value set") assert.Equal(t, "test-value", string(getResponse.Data), "Value retrieved should be equal to value set")
}) })
@ -115,11 +116,11 @@ func testGet(t *testing.T, ociProperties map[string]string) {
testKey := "unexpired-ttl-test-key" testKey := "unexpired-ttl-test-key"
err := statestore.Init(meta) err := statestore.Init(meta)
assert.Nil(t, err) assert.Nil(t, err)
err = statestore.Set(&state.SetRequest{Key: testKey, Value: []byte("test-value"), Metadata: (map[string]string{ err = statestore.Set(context.TODO(), &state.SetRequest{Key: testKey, Value: []byte("test-value"), Metadata: (map[string]string{
"ttlInSeconds": "100", "ttlInSeconds": "100",
})}) })})
assert.Nil(t, err) assert.Nil(t, err)
getResponse, err := statestore.Get(&state.GetRequest{Key: testKey}) getResponse, err := statestore.Get(context.TODO(), &state.GetRequest{Key: testKey})
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, "test-value", string(getResponse.Data), "Value retrieved should be equal to value set despite TTL setting") assert.Equal(t, "test-value", string(getResponse.Data), "Value retrieved should be equal to value set despite TTL setting")
}) })
@ -127,23 +128,23 @@ func testGet(t *testing.T, ociProperties map[string]string) {
testKey := "never-expiring-ttl-test-key" testKey := "never-expiring-ttl-test-key"
err := statestore.Init(meta) err := statestore.Init(meta)
assert.Nil(t, err) assert.Nil(t, err)
err = statestore.Set(&state.SetRequest{Key: testKey, Value: []byte("test-value"), Metadata: (map[string]string{ err = statestore.Set(context.TODO(), &state.SetRequest{Key: testKey, Value: []byte("test-value"), Metadata: (map[string]string{
"ttlInSeconds": "-1", "ttlInSeconds": "-1",
})}) })})
assert.Nil(t, err) assert.Nil(t, err)
getResponse, err := statestore.Get(&state.GetRequest{Key: testKey}) getResponse, err := statestore.Get(context.TODO(), &state.GetRequest{Key: testKey})
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, "test-value", string(getResponse.Data), "Value retrieved should be equal (TTL setting of -1 means never expire)") assert.Equal(t, "test-value", string(getResponse.Data), "Value retrieved should be equal (TTL setting of -1 means never expire)")
}) })
t.Run("Get an expired (TTL in the past) state element", func(t *testing.T) { t.Run("Get an expired (TTL in the past) state element", func(t *testing.T) {
err := statestore.Init(meta) err := statestore.Init(meta)
assert.Nil(t, err) assert.Nil(t, err)
err = statestore.Set(&state.SetRequest{Key: "ttl-test-key", Value: []byte("test-value"), Metadata: (map[string]string{ err = statestore.Set(context.TODO(), &state.SetRequest{Key: "ttl-test-key", Value: []byte("test-value"), Metadata: (map[string]string{
"ttlInSeconds": "1", "ttlInSeconds": "1",
})}) })})
assert.Nil(t, err) assert.Nil(t, err)
time.Sleep(time.Second * 2) time.Sleep(time.Second * 2)
getResponse, err := statestore.Get(&state.GetRequest{Key: "ttl-test-key"}) getResponse, err := statestore.Get(context.TODO(), &state.GetRequest{Key: "ttl-test-key"})
assert.Equal(t, &state.GetResponse{}, getResponse, "Response must be empty") assert.Equal(t, &state.GetResponse{}, getResponse, "Response must be empty")
assert.NoError(t, err, "Expired element must not be treated as error") assert.NoError(t, err, "Expired element must not be treated as error")
}) })
@ -156,7 +157,7 @@ func testSet(t *testing.T, ociProperties map[string]string) {
t.Run("Set without a key", func(t *testing.T) { t.Run("Set without a key", func(t *testing.T) {
err := statestore.Init(meta) err := statestore.Init(meta)
assert.Nil(t, err) assert.Nil(t, err)
err = statestore.Set(&state.SetRequest{Value: []byte("test-value")}) err = statestore.Set(context.TODO(), &state.SetRequest{Value: []byte("test-value")})
assert.Equal(t, err, fmt.Errorf("key for value to set was missing from request"), "Lacking Key results in error") assert.Equal(t, err, fmt.Errorf("key for value to set was missing from request"), "Lacking Key results in error")
}) })
t.Run("Regular Set Operation", func(t *testing.T) { t.Run("Regular Set Operation", func(t *testing.T) {
@ -164,9 +165,9 @@ func testSet(t *testing.T, ociProperties map[string]string) {
err := statestore.Init(meta) err := statestore.Init(meta)
assert.Nil(t, err) assert.Nil(t, err)
err = statestore.Set(&state.SetRequest{Key: testKey, Value: []byte("test-value")}) err = statestore.Set(context.TODO(), &state.SetRequest{Key: testKey, Value: []byte("test-value")})
assert.Nil(t, err, "Setting a value with a proper key should be errorfree") assert.Nil(t, err, "Setting a value with a proper key should be errorfree")
getResponse, err := statestore.Get(&state.GetRequest{Key: testKey}) getResponse, err := statestore.Get(context.TODO(), &state.GetRequest{Key: testKey})
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, "test-value", string(getResponse.Data), "Value retrieved should be equal to value set") assert.Equal(t, "test-value", string(getResponse.Data), "Value retrieved should be equal to value set")
assert.NotNil(t, *getResponse.ETag, "ETag should be set") assert.NotNil(t, *getResponse.ETag, "ETag should be set")
@ -176,20 +177,20 @@ func testSet(t *testing.T, ociProperties map[string]string) {
err := statestore.Init(meta) err := statestore.Init(meta)
assert.Nil(t, err) assert.Nil(t, err)
err = statestore.Set(&state.SetRequest{Key: testKey, Value: []byte("test-value")}) err = statestore.Set(context.TODO(), &state.SetRequest{Key: testKey, Value: []byte("test-value")})
assert.Nil(t, err, "Setting a value with a proper composite key should be errorfree") assert.Nil(t, err, "Setting a value with a proper composite key should be errorfree")
getResponse, err := statestore.Get(&state.GetRequest{Key: testKey}) getResponse, err := statestore.Get(context.TODO(), &state.GetRequest{Key: testKey})
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, "test-value", string(getResponse.Data), "Value retrieved should be equal to value set") assert.Equal(t, "test-value", string(getResponse.Data), "Value retrieved should be equal to value set")
assert.NotNil(t, *getResponse.ETag, "ETag should be set") assert.NotNil(t, *getResponse.ETag, "ETag should be set")
}) })
t.Run("Regular Set Operation with TTL", func(t *testing.T) { t.Run("Regular Set Operation with TTL", func(t *testing.T) {
testKey := "test-key-with-ttl" testKey := "test-key-with-ttl"
err := statestore.Set(&state.SetRequest{Key: testKey, Value: []byte("test-value"), Metadata: (map[string]string{ err := statestore.Set(context.TODO(), &state.SetRequest{Key: testKey, Value: []byte("test-value"), Metadata: (map[string]string{
"ttlInSeconds": "500", "ttlInSeconds": "500",
})}) })})
assert.Nil(t, err, "Setting a value with a proper key and a correct TTL value should be errorfree") assert.Nil(t, err, "Setting a value with a proper key and a correct TTL value should be errorfree")
err = statestore.Set(&state.SetRequest{Key: testKey, Value: []byte("test-value"), Metadata: (map[string]string{ err = statestore.Set(context.TODO(), &state.SetRequest{Key: testKey, Value: []byte("test-value"), Metadata: (map[string]string{
"ttlInSeconds": "XXX", "ttlInSeconds": "XXX",
})}) })})
assert.NotNil(t, err, "Setting a value with a proper key and a incorrect TTL value should be produce an error") assert.NotNil(t, err, "Setting a value with a proper key and a incorrect TTL value should be produce an error")
@ -200,25 +201,25 @@ func testSet(t *testing.T, ociProperties map[string]string) {
err := statestore.Init(meta) err := statestore.Init(meta)
assert.Nil(t, err) assert.Nil(t, err)
err = statestore.Set(&state.SetRequest{Key: testKey, Value: []byte("test-value")}) err = statestore.Set(context.TODO(), &state.SetRequest{Key: testKey, Value: []byte("test-value")})
assert.Nil(t, err, "Setting a value with a proper key should be errorfree") assert.Nil(t, err, "Setting a value with a proper key should be errorfree")
getResponse, _ := statestore.Get(&state.GetRequest{Key: testKey}) getResponse, _ := statestore.Get(context.TODO(), &state.GetRequest{Key: testKey})
etag := getResponse.ETag etag := getResponse.ETag
err = statestore.Set(&state.SetRequest{Key: testKey, Value: []byte("overwritten-value"), ETag: etag, Options: state.SetStateOption{ err = statestore.Set(context.TODO(), &state.SetRequest{Key: testKey, Value: []byte("overwritten-value"), ETag: etag, Options: state.SetStateOption{
Concurrency: state.FirstWrite, Concurrency: state.FirstWrite,
}}) }})
assert.Nil(t, err, "Updating value with proper etag should go fine") assert.Nil(t, err, "Updating value with proper etag should go fine")
err = statestore.Set(&state.SetRequest{Key: testKey, Value: []byte("more-overwritten-value"), ETag: etag, Options: state.SetStateOption{ err = statestore.Set(context.TODO(), &state.SetRequest{Key: testKey, Value: []byte("more-overwritten-value"), ETag: etag, Options: state.SetStateOption{
Concurrency: state.FirstWrite, Concurrency: state.FirstWrite,
}}) }})
assert.NotNil(t, err, "Updating value with the old etag should be refused") assert.NotNil(t, err, "Updating value with the old etag should be refused")
// retrieve the latest etag - assigned by the previous set operation. // retrieve the latest etag - assigned by the previous set operation.
getResponse, _ = statestore.Get(&state.GetRequest{Key: testKey}) getResponse, _ = statestore.Get(context.TODO(), &state.GetRequest{Key: testKey})
assert.NotNil(t, *getResponse.ETag, "ETag should be set") assert.NotNil(t, *getResponse.ETag, "ETag should be set")
etag = getResponse.ETag etag = getResponse.ETag
err = statestore.Set(&state.SetRequest{Key: testKey, Value: []byte("more-overwritten-value"), ETag: etag, Options: state.SetStateOption{ err = statestore.Set(context.TODO(), &state.SetRequest{Key: testKey, Value: []byte("more-overwritten-value"), ETag: etag, Options: state.SetStateOption{
Concurrency: state.FirstWrite, Concurrency: state.FirstWrite,
}}) }})
assert.Nil(t, err, "Updating value with the latest etag should be accepted") assert.Nil(t, err, "Updating value with the latest etag should be accepted")
@ -232,7 +233,7 @@ func testDelete(t *testing.T, ociProperties map[string]string) {
t.Run("Delete without a key", func(t *testing.T) { t.Run("Delete without a key", func(t *testing.T) {
err := s.Init(m) err := s.Init(m)
assert.Nil(t, err) assert.Nil(t, err)
err = s.Delete(&state.DeleteRequest{}) err = s.Delete(context.TODO(), &state.DeleteRequest{})
assert.Equal(t, err, fmt.Errorf("key for value to delete was missing from request"), "Lacking Key results in error") assert.Equal(t, err, fmt.Errorf("key for value to delete was missing from request"), "Lacking Key results in error")
}) })
t.Run("Regular Delete Operation", func(t *testing.T) { t.Run("Regular Delete Operation", func(t *testing.T) {
@ -240,9 +241,9 @@ func testDelete(t *testing.T, ociProperties map[string]string) {
err := s.Init(m) err := s.Init(m)
assert.Nil(t, err) assert.Nil(t, err)
err = s.Set(&state.SetRequest{Key: testKey, Value: []byte("test-value")}) err = s.Set(context.TODO(), &state.SetRequest{Key: testKey, Value: []byte("test-value")})
assert.Nil(t, err, "Setting a value with a proper key should be errorfree") assert.Nil(t, err, "Setting a value with a proper key should be errorfree")
err = s.Delete(&state.DeleteRequest{Key: testKey}) err = s.Delete(context.TODO(), &state.DeleteRequest{Key: testKey})
assert.Nil(t, err, "Deleting an existing value with a proper key should be errorfree") assert.Nil(t, err, "Deleting an existing value with a proper key should be errorfree")
}) })
t.Run("Regular Delete Operation for composite key", func(t *testing.T) { t.Run("Regular Delete Operation for composite key", func(t *testing.T) {
@ -250,13 +251,13 @@ func testDelete(t *testing.T, ociProperties map[string]string) {
err := s.Init(m) err := s.Init(m)
assert.Nil(t, err) assert.Nil(t, err)
err = s.Set(&state.SetRequest{Key: testKey, Value: []byte("test-value")}) err = s.Set(context.TODO(), &state.SetRequest{Key: testKey, Value: []byte("test-value")})
assert.Nil(t, err, "Setting a value with a proper composite key should be errorfree") assert.Nil(t, err, "Setting a value with a proper composite key should be errorfree")
err = s.Delete(&state.DeleteRequest{Key: testKey}) err = s.Delete(context.TODO(), &state.DeleteRequest{Key: testKey})
assert.Nil(t, err, "Deleting an existing value with a proper composite key should be errorfree") assert.Nil(t, err, "Deleting an existing value with a proper composite key should be errorfree")
}) })
t.Run("Delete with an unknown key", func(t *testing.T) { t.Run("Delete with an unknown key", func(t *testing.T) {
err := s.Delete(&state.DeleteRequest{Key: "unknownKey"}) err := s.Delete(context.TODO(), &state.DeleteRequest{Key: "unknownKey"})
assert.Contains(t, err.Error(), "404", "Unknown Key results in error: http status code 404, object not found") assert.Contains(t, err.Error(), "404", "Unknown Key results in error: http status code 404, object not found")
}) })
@ -265,18 +266,18 @@ func testDelete(t *testing.T, ociProperties map[string]string) {
err := s.Init(m) err := s.Init(m)
assert.Nil(t, err) assert.Nil(t, err)
// create document. // create document.
err = s.Set(&state.SetRequest{Key: testKey, Value: []byte("test-value")}) err = s.Set(context.TODO(), &state.SetRequest{Key: testKey, Value: []byte("test-value")})
assert.Nil(t, err, "Setting a value with a proper key should be errorfree") assert.Nil(t, err, "Setting a value with a proper key should be errorfree")
getResponse, _ := s.Get(&state.GetRequest{Key: testKey}) getResponse, _ := s.Get(context.TODO(), &state.GetRequest{Key: testKey})
etag := getResponse.ETag etag := getResponse.ETag
incorrectETag := "someRandomETag" incorrectETag := "someRandomETag"
err = s.Delete(&state.DeleteRequest{Key: testKey, ETag: &incorrectETag, Options: state.DeleteStateOption{ err = s.Delete(context.TODO(), &state.DeleteRequest{Key: testKey, ETag: &incorrectETag, Options: state.DeleteStateOption{
Concurrency: state.FirstWrite, Concurrency: state.FirstWrite,
}}) }})
assert.NotNil(t, err, "Deleting value with an incorrect etag should be prevented") assert.NotNil(t, err, "Deleting value with an incorrect etag should be prevented")
err = s.Delete(&state.DeleteRequest{Key: testKey, ETag: etag, Options: state.DeleteStateOption{ err = s.Delete(context.TODO(), &state.DeleteRequest{Key: testKey, ETag: etag, Options: state.DeleteStateOption{
Concurrency: state.FirstWrite, Concurrency: state.FirstWrite,
}}) }})
assert.Nil(t, err, "Deleting value with proper etag should go fine") assert.Nil(t, err, "Deleting value with proper etag should go fine")

View File

@ -236,24 +236,24 @@ func TestGetWithMockClient(t *testing.T) {
s.client = mockClient s.client = mockClient
t.Parallel() t.Parallel()
t.Run("Test regular Get", func(t *testing.T) { t.Run("Test regular Get", func(t *testing.T) {
getResponse, err := s.Get(&state.GetRequest{Key: "test-key"}) getResponse, err := s.Get(context.TODO(), &state.GetRequest{Key: "test-key"})
assert.True(t, mockClient.getIsCalled, "function Get should be invoked on the mockClient") assert.True(t, mockClient.getIsCalled, "function Get should be invoked on the mockClient")
assert.Equal(t, "Hello World", string(getResponse.Data), "Value retrieved should be equal to value set") assert.Equal(t, "Hello World", string(getResponse.Data), "Value retrieved should be equal to value set")
assert.NotNil(t, *getResponse.ETag, "ETag should be set") assert.NotNil(t, *getResponse.ETag, "ETag should be set")
assert.Nil(t, err) assert.Nil(t, err)
}) })
t.Run("Test Get with composite key", func(t *testing.T) { t.Run("Test Get with composite key", func(t *testing.T) {
getResponse, err := s.Get(&state.GetRequest{Key: "test-app||test-key"}) getResponse, err := s.Get(context.TODO(), &state.GetRequest{Key: "test-app||test-key"})
assert.Equal(t, "Hello Continent", string(getResponse.Data), "Value retrieved should be equal to value set") assert.Equal(t, "Hello Continent", string(getResponse.Data), "Value retrieved should be equal to value set")
assert.Nil(t, err) assert.Nil(t, err)
}) })
t.Run("Test Get with an unknown key", func(t *testing.T) { t.Run("Test Get with an unknown key", func(t *testing.T) {
getResponse, err := s.Get(&state.GetRequest{Key: "unknownKey"}) getResponse, err := s.Get(context.TODO(), &state.GetRequest{Key: "unknownKey"})
assert.Nil(t, getResponse.Data, "No value should be retrieved for an unknown key") assert.Nil(t, getResponse.Data, "No value should be retrieved for an unknown key")
assert.Nil(t, err, "404", "Not finding an object because of unknown key should not result in an error") assert.Nil(t, err, "404", "Not finding an object because of unknown key should not result in an error")
}) })
t.Run("Test expired element (because of TTL) ", func(t *testing.T) { t.Run("Test expired element (because of TTL) ", func(t *testing.T) {
getResponse, err := s.Get(&state.GetRequest{Key: "test-expired-ttl-key"}) getResponse, err := s.Get(context.TODO(), &state.GetRequest{Key: "test-expired-ttl-key"})
assert.Nil(t, getResponse.Data, "No value should be retrieved for an expired state element") assert.Nil(t, getResponse.Data, "No value should be retrieved for an expired state element")
assert.Nil(t, err, "Not returning an object because of expiration should not result in an error") assert.Nil(t, err, "Not returning an object because of expiration should not result in an error")
}) })
@ -289,28 +289,28 @@ func TestSetWithMockClient(t *testing.T) {
mockClient := &mockedObjectStoreClient{} mockClient := &mockedObjectStoreClient{}
statestore.client = mockClient statestore.client = mockClient
t.Run("Set without a key", func(t *testing.T) { t.Run("Set without a key", func(t *testing.T) {
err := statestore.Set(&state.SetRequest{Value: []byte("test-value")}) err := statestore.Set(context.TODO(), &state.SetRequest{Value: []byte("test-value")})
assert.Equal(t, err, fmt.Errorf("key for value to set was missing from request"), "Lacking Key results in error") assert.Equal(t, err, fmt.Errorf("key for value to set was missing from request"), "Lacking Key results in error")
}) })
t.Run("Regular Set Operation", func(t *testing.T) { t.Run("Regular Set Operation", func(t *testing.T) {
testKey := "test-key" testKey := "test-key"
err := statestore.Set(&state.SetRequest{Key: testKey, Value: []byte("test-value")}) err := statestore.Set(context.TODO(), &state.SetRequest{Key: testKey, Value: []byte("test-value")})
assert.Nil(t, err, "Setting a value with a proper key should be errorfree") assert.Nil(t, err, "Setting a value with a proper key should be errorfree")
assert.True(t, mockClient.putIsCalled, "function put should be invoked on the mockClient") assert.True(t, mockClient.putIsCalled, "function put should be invoked on the mockClient")
}) })
t.Run("Regular Set Operation with TTL", func(t *testing.T) { t.Run("Regular Set Operation with TTL", func(t *testing.T) {
testKey := "test-key" testKey := "test-key"
err := statestore.Set(&state.SetRequest{Key: testKey, Value: []byte("test-value"), Metadata: (map[string]string{ err := statestore.Set(context.TODO(), &state.SetRequest{Key: testKey, Value: []byte("test-value"), Metadata: (map[string]string{
"ttlInSeconds": "5", "ttlInSeconds": "5",
})}) })})
assert.Nil(t, err, "Setting a value with a proper key and a correct TTL value should be errorfree") assert.Nil(t, err, "Setting a value with a proper key and a correct TTL value should be errorfree")
err = statestore.Set(&state.SetRequest{Key: testKey, Value: []byte("test-value"), Metadata: (map[string]string{ err = statestore.Set(context.TODO(), &state.SetRequest{Key: testKey, Value: []byte("test-value"), Metadata: (map[string]string{
"ttlInSeconds": "XXX", "ttlInSeconds": "XXX",
})}) })})
assert.NotNil(t, err, "Setting a value with a proper key and a incorrect TTL value should be produce an error") assert.NotNil(t, err, "Setting a value with a proper key and a incorrect TTL value should be produce an error")
err = statestore.Set(&state.SetRequest{Key: testKey, Value: []byte("test-value"), Metadata: (map[string]string{ err = statestore.Set(context.TODO(), &state.SetRequest{Key: testKey, Value: []byte("test-value"), Metadata: (map[string]string{
"ttlInSeconds": "1", "ttlInSeconds": "1",
})}) })})
assert.Nil(t, err, "Setting a value with a proper key and a correct TTL value should be errorfree") assert.Nil(t, err, "Setting a value with a proper key and a correct TTL value should be errorfree")
@ -320,22 +320,22 @@ func TestSetWithMockClient(t *testing.T) {
incorrectETag := "notTheCorrectETag" incorrectETag := "notTheCorrectETag"
etag := "correctETag" etag := "correctETag"
err := statestore.Set(&state.SetRequest{Key: testKey, Value: []byte("overwritten-value"), ETag: &incorrectETag, Options: state.SetStateOption{ err := statestore.Set(context.TODO(), &state.SetRequest{Key: testKey, Value: []byte("overwritten-value"), ETag: &incorrectETag, Options: state.SetStateOption{
Concurrency: state.FirstWrite, Concurrency: state.FirstWrite,
}}) }})
assert.NotNil(t, err, "Updating value with wrong etag should fail") assert.NotNil(t, err, "Updating value with wrong etag should fail")
err = statestore.Set(&state.SetRequest{Key: testKey, Value: []byte("overwritten-value"), ETag: nil, Options: state.SetStateOption{ err = statestore.Set(context.TODO(), &state.SetRequest{Key: testKey, Value: []byte("overwritten-value"), ETag: nil, Options: state.SetStateOption{
Concurrency: state.FirstWrite, Concurrency: state.FirstWrite,
}}) }})
assert.NotNil(t, err, "Asking for FirstWrite concurrency policy without ETag should fail") assert.NotNil(t, err, "Asking for FirstWrite concurrency policy without ETag should fail")
err = statestore.Set(&state.SetRequest{Key: testKey, Value: []byte("overwritten-value"), ETag: &etag, Options: state.SetStateOption{ err = statestore.Set(context.TODO(), &state.SetRequest{Key: testKey, Value: []byte("overwritten-value"), ETag: &etag, Options: state.SetStateOption{
Concurrency: state.FirstWrite, Concurrency: state.FirstWrite,
}}) }})
assert.Nil(t, err, "Updating value with proper etag should go fine") assert.Nil(t, err, "Updating value with proper etag should go fine")
err = statestore.Set(&state.SetRequest{Key: testKey, Value: []byte("overwritten-value"), ETag: nil, Options: state.SetStateOption{ err = statestore.Set(context.TODO(), &state.SetRequest{Key: testKey, Value: []byte("overwritten-value"), ETag: nil, Options: state.SetStateOption{
Concurrency: state.FirstWrite, Concurrency: state.FirstWrite,
}}) }})
assert.NotNil(t, err, "Updating value with concurrency policy at FirstWrite should fail when ETag is missing") assert.NotNil(t, err, "Updating value with concurrency policy at FirstWrite should fail when ETag is missing")
@ -348,34 +348,34 @@ func TestDeleteWithMockClient(t *testing.T) {
mockClient := &mockedObjectStoreClient{} mockClient := &mockedObjectStoreClient{}
s.client = mockClient s.client = mockClient
t.Run("Delete without a key", func(t *testing.T) { t.Run("Delete without a key", func(t *testing.T) {
err := s.Delete(&state.DeleteRequest{}) err := s.Delete(context.TODO(), &state.DeleteRequest{})
assert.Equal(t, err, fmt.Errorf("key for value to delete was missing from request"), "Lacking Key results in error") assert.Equal(t, err, fmt.Errorf("key for value to delete was missing from request"), "Lacking Key results in error")
}) })
t.Run("Delete with an unknown key", func(t *testing.T) { t.Run("Delete with an unknown key", func(t *testing.T) {
err := s.Delete(&state.DeleteRequest{Key: "unknownKey"}) err := s.Delete(context.TODO(), &state.DeleteRequest{Key: "unknownKey"})
assert.Contains(t, err.Error(), "404", "Unknown Key results in error: http status code 404, object not found") assert.Contains(t, err.Error(), "404", "Unknown Key results in error: http status code 404, object not found")
}) })
t.Run("Regular Delete Operation", func(t *testing.T) { t.Run("Regular Delete Operation", func(t *testing.T) {
testKey := "test-key" testKey := "test-key"
err := s.Delete(&state.DeleteRequest{Key: testKey}) err := s.Delete(context.TODO(), &state.DeleteRequest{Key: testKey})
assert.Nil(t, err, "Deleting an existing value with a proper key should be errorfree") assert.Nil(t, err, "Deleting an existing value with a proper key should be errorfree")
assert.True(t, mockClient.deleteIsCalled, "function delete should be invoked on the mockClient") assert.True(t, mockClient.deleteIsCalled, "function delete should be invoked on the mockClient")
}) })
t.Run("Testing Delete & Concurrency (ETags)", func(t *testing.T) { t.Run("Testing Delete & Concurrency (ETags)", func(t *testing.T) {
testKey := "etag-test-delete-key" testKey := "etag-test-delete-key"
incorrectETag := "notTheCorrectETag" incorrectETag := "notTheCorrectETag"
err := s.Delete(&state.DeleteRequest{Key: testKey, ETag: &incorrectETag, Options: state.DeleteStateOption{ err := s.Delete(context.TODO(), &state.DeleteRequest{Key: testKey, ETag: &incorrectETag, Options: state.DeleteStateOption{
Concurrency: state.FirstWrite, Concurrency: state.FirstWrite,
}}) }})
assert.NotNil(t, err, "Deleting value with an incorrect etag should be prevented") assert.NotNil(t, err, "Deleting value with an incorrect etag should be prevented")
etag := "correctETag" etag := "correctETag"
err = s.Delete(&state.DeleteRequest{Key: testKey, ETag: &etag, Options: state.DeleteStateOption{ err = s.Delete(context.TODO(), &state.DeleteRequest{Key: testKey, ETag: &etag, Options: state.DeleteStateOption{
Concurrency: state.FirstWrite, Concurrency: state.FirstWrite,
}}) }})
assert.Nil(t, err, "Deleting value with proper etag should go fine") assert.Nil(t, err, "Deleting value with proper etag should go fine")
err = s.Delete(&state.DeleteRequest{Key: testKey, ETag: nil, Options: state.DeleteStateOption{ err = s.Delete(context.TODO(), &state.DeleteRequest{Key: testKey, ETag: nil, Options: state.DeleteStateOption{
Concurrency: state.FirstWrite, Concurrency: state.FirstWrite,
}}) }})
assert.NotNil(t, err, "Asking for FirstWrite concurrency policy without ETag should fail") assert.NotNil(t, err, "Asking for FirstWrite concurrency policy without ETag should fail")

View File

@ -14,6 +14,8 @@ limitations under the License.
package oracledatabase package oracledatabase
import ( import (
"context"
"github.com/dapr/components-contrib/state" "github.com/dapr/components-contrib/state"
) )
@ -21,9 +23,9 @@ import (
type dbAccess interface { type dbAccess interface {
Init(metadata state.Metadata) error Init(metadata state.Metadata) error
Ping() error Ping() error
Set(req *state.SetRequest) error Set(ctx context.Context, req *state.SetRequest) error
Get(req *state.GetRequest) (*state.GetResponse, error) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error)
Delete(req *state.DeleteRequest) error Delete(ctx context.Context, req *state.DeleteRequest) error
ExecuteMulti(sets []state.SetRequest, deletes []state.DeleteRequest) error ExecuteMulti(ctx context.Context, sets []state.SetRequest, deletes []state.DeleteRequest) error
Close() error // io.Closer. Close() error // io.Closer.
} }

View File

@ -14,6 +14,7 @@ limitations under the License.
package oracledatabase package oracledatabase
import ( import (
"context"
"fmt" "fmt"
"github.com/dapr/components-contrib/state" "github.com/dapr/components-contrib/state"
@ -59,38 +60,38 @@ func (o *OracleDatabase) Features() []state.Feature {
} }
// Delete removes an entity from the store. // Delete removes an entity from the store.
func (o *OracleDatabase) Delete(req *state.DeleteRequest) error { func (o *OracleDatabase) Delete(ctx context.Context, req *state.DeleteRequest) error {
return o.dbaccess.Delete(req) return o.dbaccess.Delete(ctx, req)
} }
// BulkDelete removes multiple entries from the store. // BulkDelete removes multiple entries from the store.
func (o *OracleDatabase) BulkDelete(req []state.DeleteRequest) error { func (o *OracleDatabase) BulkDelete(ctx context.Context, req []state.DeleteRequest) error {
return o.dbaccess.ExecuteMulti(nil, req) return o.dbaccess.ExecuteMulti(ctx, nil, req)
} }
// Get returns an entity from store. // Get returns an entity from store.
func (o *OracleDatabase) Get(req *state.GetRequest) (*state.GetResponse, error) { func (o *OracleDatabase) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) {
return o.dbaccess.Get(req) return o.dbaccess.Get(ctx, req)
} }
// BulkGet performs a bulks get operations. // BulkGet performs a bulks get operations.
func (o *OracleDatabase) BulkGet(req []state.GetRequest) (bool, []state.BulkGetResponse, error) { func (o *OracleDatabase) BulkGet(ctx context.Context, req []state.GetRequest) (bool, []state.BulkGetResponse, error) {
// TODO: replace with ExecuteMulti for performance. // TODO: replace with ExecuteMulti for performance.
return false, nil, nil return false, nil, nil
} }
// Set adds/updates an entity on store. // Set adds/updates an entity on store.
func (o *OracleDatabase) Set(req *state.SetRequest) error { func (o *OracleDatabase) Set(ctx context.Context, req *state.SetRequest) error {
return o.dbaccess.Set(req) return o.dbaccess.Set(ctx, req)
} }
// BulkSet adds/updates multiple entities on store. // BulkSet adds/updates multiple entities on store.
func (o *OracleDatabase) BulkSet(req []state.SetRequest) error { func (o *OracleDatabase) BulkSet(ctx context.Context, req []state.SetRequest) error {
return o.dbaccess.ExecuteMulti(req, nil) return o.dbaccess.ExecuteMulti(ctx, req, nil)
} }
// Multi handles multiple transactions. Implements TransactionalStore. // Multi handles multiple transactions. Implements TransactionalStore.
func (o *OracleDatabase) Multi(request *state.TransactionalStateRequest) error { func (o *OracleDatabase) Multi(ctx context.Context, request *state.TransactionalStateRequest) error {
var deletes []state.DeleteRequest var deletes []state.DeleteRequest
var sets []state.SetRequest var sets []state.SetRequest
for _, req := range request.Operations { for _, req := range request.Operations {
@ -115,7 +116,7 @@ func (o *OracleDatabase) Multi(request *state.TransactionalStateRequest) error {
} }
if len(sets) > 0 || len(deletes) > 0 { if len(sets) > 0 || len(deletes) > 0 {
return o.dbaccess.ExecuteMulti(sets, deletes) return o.dbaccess.ExecuteMulti(ctx, sets, deletes)
} }
return nil return nil

View File

@ -15,6 +15,7 @@ limitations under the License.
package oracledatabase package oracledatabase
import ( import (
"context"
"database/sql" "database/sql"
"encoding/json" "encoding/json"
"fmt" "fmt"
@ -223,7 +224,7 @@ func deleteItemThatDoesNotExist(t *testing.T, ods *OracleDatabase) {
deleteReq := &state.DeleteRequest{ deleteReq := &state.DeleteRequest{
Key: randomKey(), Key: randomKey(),
} }
err := ods.Delete(deleteReq) err := ods.Delete(context.TODO(), deleteReq)
assert.Nil(t, err) assert.Nil(t, err)
} }
@ -242,7 +243,7 @@ func multiWithSetOnly(t *testing.T, ods *OracleDatabase) {
}) })
} }
err := ods.Multi(&state.TransactionalStateRequest{ err := ods.Multi(context.TODO(), &state.TransactionalStateRequest{
Operations: operations, Operations: operations,
}) })
assert.Nil(t, err) assert.Nil(t, err)
@ -272,7 +273,7 @@ func multiWithDeleteOnly(t *testing.T, ods *OracleDatabase) {
}) })
} }
err := ods.Multi(&state.TransactionalStateRequest{ err := ods.Multi(context.TODO(), &state.TransactionalStateRequest{
Operations: operations, Operations: operations,
}) })
assert.Nil(t, err) assert.Nil(t, err)
@ -315,7 +316,7 @@ func multiWithDeleteAndSet(t *testing.T, ods *OracleDatabase) {
}) })
} }
err := ods.Multi(&state.TransactionalStateRequest{ err := ods.Multi(context.TODO(), &state.TransactionalStateRequest{
Operations: operations, Operations: operations,
}) })
assert.Nil(t, err) assert.Nil(t, err)
@ -345,7 +346,7 @@ func deleteWithInvalidEtagFails(t *testing.T, ods *OracleDatabase) {
Concurrency: state.FirstWrite, Concurrency: state.FirstWrite,
}, },
} }
err := ods.Delete(deleteReq) err := ods.Delete(context.TODO(), deleteReq)
assert.NotNil(t, err, "Deleting an item with the wrong etag while enforcing FirstWrite policy should fail") assert.NotNil(t, err, "Deleting an item with the wrong etag while enforcing FirstWrite policy should fail")
} }
@ -353,7 +354,7 @@ func deleteWithNoKeyFails(t *testing.T, ods *OracleDatabase) {
deleteReq := &state.DeleteRequest{ deleteReq := &state.DeleteRequest{
Key: "", Key: "",
} }
err := ods.Delete(deleteReq) err := ods.Delete(context.TODO(), deleteReq)
assert.NotNil(t, err) assert.NotNil(t, err)
} }
@ -371,7 +372,7 @@ func newItemWithEtagFails(t *testing.T, ods *OracleDatabase) {
}, },
} }
err := ods.Set(setReq) err := ods.Set(context.TODO(), setReq)
assert.NotNil(t, err) assert.NotNil(t, err)
} }
@ -401,7 +402,7 @@ func updateWithOldEtagFails(t *testing.T, ods *OracleDatabase) {
Concurrency: state.FirstWrite, Concurrency: state.FirstWrite,
}, },
} }
err := ods.Set(setReq) err := ods.Set(context.TODO(), setReq)
assert.NotNil(t, err) assert.NotNil(t, err)
} }
@ -423,7 +424,7 @@ func updateAndDeleteWithEtagSucceeds(t *testing.T, ods *OracleDatabase) {
Concurrency: state.FirstWrite, Concurrency: state.FirstWrite,
}, },
} }
err := ods.Set(setReq) err := ods.Set(context.TODO(), setReq)
assert.Nil(t, err, "Setting the item should be successful") assert.Nil(t, err, "Setting the item should be successful")
updateResponse, updatedItem := getItem(t, ods, key) updateResponse, updatedItem := getItem(t, ods, key)
assert.Equal(t, value, updatedItem) assert.Equal(t, value, updatedItem)
@ -439,7 +440,7 @@ func updateAndDeleteWithEtagSucceeds(t *testing.T, ods *OracleDatabase) {
Concurrency: state.FirstWrite, Concurrency: state.FirstWrite,
}, },
} }
err = ods.Delete(deleteReq) err = ods.Delete(context.TODO(), deleteReq)
assert.Nil(t, err, "Deleting an item with the right etag while enforcing FirstWrite policy should succeed") assert.Nil(t, err, "Deleting an item with the right etag while enforcing FirstWrite policy should succeed")
// Item is not in the data store. // Item is not in the data store.
@ -465,7 +466,7 @@ func updateAndDeleteWithWrongEtagAndNoFirstWriteSucceeds(t *testing.T, ods *Orac
Concurrency: state.LastWrite, Concurrency: state.LastWrite,
}, },
} }
err := ods.Set(setReq) err := ods.Set(context.TODO(), setReq)
assert.Nil(t, err, "Setting the item should be successful") assert.Nil(t, err, "Setting the item should be successful")
_, updatedItem := getItem(t, ods, key) _, updatedItem := getItem(t, ods, key)
assert.Equal(t, value, updatedItem) assert.Equal(t, value, updatedItem)
@ -478,7 +479,7 @@ func updateAndDeleteWithWrongEtagAndNoFirstWriteSucceeds(t *testing.T, ods *Orac
Concurrency: state.LastWrite, Concurrency: state.LastWrite,
}, },
} }
err = ods.Delete(deleteReq) err = ods.Delete(context.TODO(), deleteReq)
assert.Nil(t, err, "Deleting an item with the wrong etag but not enforcing FirstWrite policy should succeed") assert.Nil(t, err, "Deleting an item with the wrong etag but not enforcing FirstWrite policy should succeed")
// Item is not in the data store. // Item is not in the data store.
@ -500,7 +501,7 @@ func getItemWithNoKey(t *testing.T, ods *OracleDatabase) {
Key: "", Key: "",
} }
response, getErr := ods.Get(getReq) response, getErr := ods.Get(context.TODO(), getReq)
assert.NotNil(t, getErr) assert.NotNil(t, getErr)
assert.Nil(t, response) assert.Nil(t, response)
} }
@ -548,7 +549,7 @@ func setTTLUpdatesExpiry(t *testing.T, ods *OracleDatabase) {
}, },
} }
err := ods.Set(setReq) err := ods.Set(context.TODO(), setReq)
assert.Nil(t, err) assert.Nil(t, err)
connectionString := getConnectionString() connectionString := getConnectionString()
if getWalletLocation() != "" { if getWalletLocation() != "" {
@ -580,10 +581,10 @@ func setNoTTLUpdatesExpiry(t *testing.T, ods *OracleDatabase) {
"ttlInSeconds": "1000", "ttlInSeconds": "1000",
}, },
} }
err := ods.Set(setReq) err := ods.Set(context.TODO(), setReq)
assert.Nil(t, err) assert.Nil(t, err)
delete(setReq.Metadata, "ttlInSeconds") delete(setReq.Metadata, "ttlInSeconds")
err = ods.Set(setReq) err = ods.Set(context.TODO(), setReq)
assert.Nil(t, err) assert.Nil(t, err)
connectionString := getConnectionString() connectionString := getConnectionString()
if getWalletLocation() != "" { if getWalletLocation() != "" {
@ -614,11 +615,11 @@ func expiredStateCannotBeRead(t *testing.T, ods *OracleDatabase) {
"ttlInSeconds": "1", "ttlInSeconds": "1",
}, },
} }
err := ods.Set(setReq) err := ods.Set(context.TODO(), setReq)
assert.Nil(t, err) assert.Nil(t, err)
time.Sleep(time.Second * time.Duration(2)) time.Sleep(time.Second * time.Duration(2))
getResponse, err := ods.Get(&state.GetRequest{Key: key}) getResponse, err := ods.Get(context.TODO(), &state.GetRequest{Key: key})
assert.Equal(t, &state.GetResponse{}, getResponse, "Response must be empty") assert.Equal(t, &state.GetResponse{}, getResponse, "Response must be empty")
assert.NoError(t, err, "Expired element must not be treated as error") assert.NoError(t, err, "Expired element must not be treated as error")
@ -639,7 +640,7 @@ func unexpiredStateCanBeRead(t *testing.T, ods *OracleDatabase) {
"ttlInSeconds": "10000", "ttlInSeconds": "10000",
}, },
} }
err := ods.Set(setReq) err := ods.Set(context.TODO(), setReq)
assert.Nil(t, err) assert.Nil(t, err)
_, getValue := getItem(t, ods, key) _, getValue := getItem(t, ods, key)
assert.Equal(t, value.Color, getValue.Color, "Response must be as set") assert.Equal(t, value.Color, getValue.Color, "Response must be as set")
@ -653,7 +654,7 @@ func setItemWithNoKey(t *testing.T, ods *OracleDatabase) {
Key: "", Key: "",
} }
err := ods.Set(setReq) err := ods.Set(context.TODO(), setReq)
assert.NotNil(t, err) assert.NotNil(t, err)
} }
@ -702,7 +703,7 @@ func testSetItemWithInvalidTTL(t *testing.T, ods *OracleDatabase) {
"ttlInSeconds": "XX", "ttlInSeconds": "XX",
}), }),
} }
err := ods.Set(setReq) err := ods.Set(context.TODO(), setReq)
assert.NotNil(t, err, "Setting a value with a proper key and a incorrect TTL value should be produce an error") assert.NotNil(t, err, "Setting a value with a proper key and a incorrect TTL value should be produce an error")
} }
@ -714,7 +715,7 @@ func testSetItemWithNegativeTTL(t *testing.T, ods *OracleDatabase) {
"ttlInSeconds": "-10", "ttlInSeconds": "-10",
}), }),
} }
err := ods.Set(setReq) err := ods.Set(context.TODO(), setReq)
assert.NotNil(t, err, "Setting a value with a proper key and a negative (other than -1) TTL value should be produce an error") assert.NotNil(t, err, "Setting a value with a proper key and a negative (other than -1) TTL value should be produce an error")
} }
@ -731,7 +732,7 @@ func testBulkSetAndBulkDelete(t *testing.T, ods *OracleDatabase) {
}, },
} }
err := ods.BulkSet(setReq) err := ods.BulkSet(context.TODO(), setReq)
assert.Nil(t, err) assert.Nil(t, err)
assert.True(t, storeItemExists(t, setReq[0].Key)) assert.True(t, storeItemExists(t, setReq[0].Key))
assert.True(t, storeItemExists(t, setReq[1].Key)) assert.True(t, storeItemExists(t, setReq[1].Key))
@ -745,7 +746,7 @@ func testBulkSetAndBulkDelete(t *testing.T, ods *OracleDatabase) {
}, },
} }
err = ods.BulkDelete(deleteReq) err = ods.BulkDelete(context.TODO(), deleteReq)
assert.Nil(t, err) assert.Nil(t, err)
assert.False(t, storeItemExists(t, setReq[0].Key)) assert.False(t, storeItemExists(t, setReq[0].Key))
assert.False(t, storeItemExists(t, setReq[1].Key)) assert.False(t, storeItemExists(t, setReq[1].Key))
@ -812,7 +813,7 @@ func setItem(t *testing.T, ods *OracleDatabase, key string, value interface{}, e
Options: setOptions, Options: setOptions,
} }
err := ods.Set(setReq) err := ods.Set(context.TODO(), setReq)
assert.Nil(t, err) assert.Nil(t, err)
itemExists := storeItemExists(t, key) itemExists := storeItemExists(t, key)
assert.True(t, itemExists, "Item should exist after set has been executed ") assert.True(t, itemExists, "Item should exist after set has been executed ")
@ -824,7 +825,7 @@ func getItem(t *testing.T, ods *OracleDatabase, key string) (*state.GetResponse,
Options: state.GetStateOption{}, Options: state.GetStateOption{},
} }
response, getErr := ods.Get(getReq) response, getErr := ods.Get(context.TODO(), getReq)
assert.Nil(t, getErr) assert.Nil(t, getErr)
assert.NotNil(t, response) assert.NotNil(t, response)
outputObject := &fakeItem{} outputObject := &fakeItem{}
@ -840,7 +841,7 @@ func deleteItem(t *testing.T, ods *OracleDatabase, key string, etag *string) {
Options: state.DeleteStateOption{}, Options: state.DeleteStateOption{},
} }
deleteErr := ods.Delete(deleteReq) deleteErr := ods.Delete(context.TODO(), deleteReq)
assert.Nil(t, deleteErr) assert.Nil(t, deleteErr)
assert.False(t, storeItemExists(t, key), "item should no longer exist after delete has been performed") assert.False(t, storeItemExists(t, key), "item should no longer exist after delete has been performed")
} }

View File

@ -15,6 +15,7 @@ limitations under the License.
package oracledatabase package oracledatabase
import ( import (
"context"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@ -48,23 +49,23 @@ func (m *fakeDBaccess) Init(metadata state.Metadata) error {
return nil return nil
} }
func (m *fakeDBaccess) Set(req *state.SetRequest) error { func (m *fakeDBaccess) Set(ctx context.Context, req *state.SetRequest) error {
m.setExecuted = true m.setExecuted = true
return nil return nil
} }
func (m *fakeDBaccess) Get(req *state.GetRequest) (*state.GetResponse, error) { func (m *fakeDBaccess) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) {
m.getExecuted = true m.getExecuted = true
return nil, nil return nil, nil
} }
func (m *fakeDBaccess) Delete(req *state.DeleteRequest) error { func (m *fakeDBaccess) Delete(ctx context.Context, req *state.DeleteRequest) error {
return nil return nil
} }
func (m *fakeDBaccess) ExecuteMulti(sets []state.SetRequest, deletes []state.DeleteRequest) error { func (m *fakeDBaccess) ExecuteMulti(ctx context.Context, sets []state.SetRequest, deletes []state.DeleteRequest) error {
return nil return nil
} }
@ -84,7 +85,7 @@ func TestMultiWithNoRequestsReturnsNil(t *testing.T) {
t.Parallel() t.Parallel()
var operations []state.TransactionalStateOperation var operations []state.TransactionalStateOperation
ods := createOracleDatabase(t) ods := createOracleDatabase(t)
err := ods.Multi(&state.TransactionalStateRequest{ err := ods.Multi(context.TODO(), &state.TransactionalStateRequest{
Operations: operations, Operations: operations,
}) })
assert.Nil(t, err) assert.Nil(t, err)
@ -100,7 +101,7 @@ func TestInvalidMultiAction(t *testing.T) {
}) })
ods := createOracleDatabase(t) ods := createOracleDatabase(t)
err := ods.Multi(&state.TransactionalStateRequest{ err := ods.Multi(context.TODO(), &state.TransactionalStateRequest{
Operations: operations, Operations: operations,
}) })
assert.NotNil(t, err) assert.NotNil(t, err)
@ -116,7 +117,7 @@ func TestValidSetRequest(t *testing.T) {
}) })
ods := createOracleDatabase(t) ods := createOracleDatabase(t)
err := ods.Multi(&state.TransactionalStateRequest{ err := ods.Multi(context.TODO(), &state.TransactionalStateRequest{
Operations: operations, Operations: operations,
}) })
assert.Nil(t, err) assert.Nil(t, err)
@ -132,7 +133,7 @@ func TestInvalidMultiSetRequest(t *testing.T) {
}) })
ods := createOracleDatabase(t) ods := createOracleDatabase(t)
err := ods.Multi(&state.TransactionalStateRequest{ err := ods.Multi(context.TODO(), &state.TransactionalStateRequest{
Operations: operations, Operations: operations,
}) })
assert.NotNil(t, err) assert.NotNil(t, err)
@ -148,7 +149,7 @@ func TestValidMultiDeleteRequest(t *testing.T) {
}) })
ods := createOracleDatabase(t) ods := createOracleDatabase(t)
err := ods.Multi(&state.TransactionalStateRequest{ err := ods.Multi(context.TODO(), &state.TransactionalStateRequest{
Operations: operations, Operations: operations,
}) })
assert.Nil(t, err) assert.Nil(t, err)
@ -164,7 +165,7 @@ func TestInvalidMultiDeleteRequest(t *testing.T) {
}) })
ods := createOracleDatabase(t) ods := createOracleDatabase(t)
err := ods.Multi(&state.TransactionalStateRequest{ err := ods.Multi(context.TODO(), &state.TransactionalStateRequest{
Operations: operations, Operations: operations,
}) })
assert.NotNil(t, err) assert.NotNil(t, err)

View File

@ -14,6 +14,7 @@ limitations under the License.
package oracledatabase package oracledatabase
import ( import (
"context"
"database/sql" "database/sql"
"encoding/base64" "encoding/base64"
"encoding/json" "encoding/json"
@ -96,8 +97,8 @@ func (o *oracleDatabaseAccess) Init(metadata state.Metadata) error {
} }
// Set makes an insert or update to the database. // Set makes an insert or update to the database.
func (o *oracleDatabaseAccess) Set(req *state.SetRequest) error { func (o *oracleDatabaseAccess) Set(ctx context.Context, req *state.SetRequest) error {
return state.SetWithOptions(o.setValue, req) return state.SetWithOptions(ctx, o.setValue, req)
} }
func parseTTL(requestMetadata map[string]string) (*int, error) { func parseTTL(requestMetadata map[string]string) (*int, error) {
@ -115,7 +116,7 @@ func parseTTL(requestMetadata map[string]string) (*int, error) {
} }
// setValue is an internal implementation of set to enable passing the logic to state.SetWithRetries as a func. // setValue is an internal implementation of set to enable passing the logic to state.SetWithRetries as a func.
func (o *oracleDatabaseAccess) setValue(req *state.SetRequest) error { func (o *oracleDatabaseAccess) setValue(ctx context.Context, req *state.SetRequest) error {
o.logger.Debug("Setting state value in OracleDatabase") o.logger.Debug("Setting state value in OracleDatabase")
err := state.CheckRequestOptions(req.Options) err := state.CheckRequestOptions(req.Options)
if err != nil { if err != nil {
@ -182,14 +183,14 @@ func (o *oracleDatabaseAccess) setValue(req *state.SetRequest) error {
WHEN MATCHED THEN UPDATE SET value = new_state_to_store.value, binary_yn = new_state_to_store.binary_yn, update_time = systimestamp, etag = new_state_to_store.etag, t.expiration_time = case when new_state_to_store.ttl_in_seconds >0 then systimestamp + numtodsinterval(new_state_to_store.ttl_in_seconds, 'SECOND') end WHEN MATCHED THEN UPDATE SET value = new_state_to_store.value, binary_yn = new_state_to_store.binary_yn, update_time = systimestamp, etag = new_state_to_store.etag, t.expiration_time = case when new_state_to_store.ttl_in_seconds >0 then systimestamp + numtodsinterval(new_state_to_store.ttl_in_seconds, 'SECOND') end
WHEN NOT MATCHED THEN INSERT (t.key, t.value, t.binary_yn, t.etag, t.expiration_time) values (new_state_to_store.key, new_state_to_store.value, new_state_to_store.binary_yn, new_state_to_store.etag, case when new_state_to_store.ttl_in_seconds >0 then systimestamp + numtodsinterval(new_state_to_store.ttl_in_seconds, 'SECOND') end ) `, WHEN NOT MATCHED THEN INSERT (t.key, t.value, t.binary_yn, t.etag, t.expiration_time) values (new_state_to_store.key, new_state_to_store.value, new_state_to_store.binary_yn, new_state_to_store.etag, case when new_state_to_store.ttl_in_seconds >0 then systimestamp + numtodsinterval(new_state_to_store.ttl_in_seconds, 'SECOND') end ) `,
tableName) tableName)
result, err = tx.Exec(mergeStatement, req.Key, value, binaryYN, etag, ttlSeconds) result, err = tx.ExecContext(ctx, mergeStatement, req.Key, value, binaryYN, etag, ttlSeconds)
} else { } else {
// when first write policy is indicated, an existing record has to be updated - one that has the etag provided. // when first write policy is indicated, an existing record has to be updated - one that has the etag provided.
updateStatement := fmt.Sprintf( updateStatement := fmt.Sprintf(
`UPDATE %s SET value = :value, binary_yn = :binary_yn, etag = :new_etag `UPDATE %s SET value = :value, binary_yn = :binary_yn, etag = :new_etag
WHERE key = :key AND etag = :etag`, WHERE key = :key AND etag = :etag`,
tableName) tableName)
result, err = tx.Exec(updateStatement, value, binaryYN, etag, req.Key, *req.ETag) result, err = tx.ExecContext(ctx, updateStatement, value, binaryYN, etag, req.Key, *req.ETag)
} }
if err != nil { if err != nil {
if req.ETag != nil && *req.ETag != "" { if req.ETag != nil && *req.ETag != "" {
@ -214,7 +215,7 @@ func (o *oracleDatabaseAccess) setValue(req *state.SetRequest) error {
} }
// Get returns data from the database. If data does not exist for the key an empty state.GetResponse will be returned. // Get returns data from the database. If data does not exist for the key an empty state.GetResponse will be returned.
func (o *oracleDatabaseAccess) Get(req *state.GetRequest) (*state.GetResponse, error) { func (o *oracleDatabaseAccess) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) {
o.logger.Debug("Getting state value from OracleDatabase") o.logger.Debug("Getting state value from OracleDatabase")
if req.Key == "" { if req.Key == "" {
return nil, fmt.Errorf("missing key in get operation") return nil, fmt.Errorf("missing key in get operation")
@ -222,7 +223,7 @@ func (o *oracleDatabaseAccess) Get(req *state.GetRequest) (*state.GetResponse, e
var value string var value string
var binaryYN string var binaryYN string
var etag string var etag string
err := o.db.QueryRow(fmt.Sprintf("SELECT value, binary_yn, etag FROM %s WHERE key = :key and (expiration_time is null or expiration_time > systimestamp)", tableName), req.Key).Scan(&value, &binaryYN, &etag) err := o.db.QueryRowContext(ctx, fmt.Sprintf("SELECT value, binary_yn, etag FROM %s WHERE key = :key and (expiration_time is null or expiration_time > systimestamp)", tableName), req.Key).Scan(&value, &binaryYN, &etag)
if err != nil { if err != nil {
// If no rows exist, return an empty response, otherwise return the error. // If no rows exist, return an empty response, otherwise return the error.
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
@ -253,12 +254,12 @@ func (o *oracleDatabaseAccess) Get(req *state.GetRequest) (*state.GetResponse, e
} }
// Delete removes an item from the state store. // Delete removes an item from the state store.
func (o *oracleDatabaseAccess) Delete(req *state.DeleteRequest) error { func (o *oracleDatabaseAccess) Delete(ctx context.Context, req *state.DeleteRequest) error {
return state.DeleteWithOptions(o.deleteValue, req) return state.DeleteWithOptions(ctx, o.deleteValue, req)
} }
// deleteValue is an internal implementation of delete to enable passing the logic to state.DeleteWithRetries as a func. // deleteValue is an internal implementation of delete to enable passing the logic to state.DeleteWithRetries as a func.
func (o *oracleDatabaseAccess) deleteValue(req *state.DeleteRequest) error { func (o *oracleDatabaseAccess) deleteValue(ctx context.Context, req *state.DeleteRequest) error {
o.logger.Debug("Deleting state value from OracleDatabase") o.logger.Debug("Deleting state value from OracleDatabase")
if req.Key == "" { if req.Key == "" {
return fmt.Errorf("missing key in delete operation") return fmt.Errorf("missing key in delete operation")
@ -280,9 +281,9 @@ func (o *oracleDatabaseAccess) deleteValue(req *state.DeleteRequest) error {
} }
// QUESTION: only check for etag if FirstWrite specified - or always when etag is supplied?? // QUESTION: only check for etag if FirstWrite specified - or always when etag is supplied??
if req.Options.Concurrency != state.FirstWrite { if req.Options.Concurrency != state.FirstWrite {
result, err = tx.Exec("DELETE FROM state WHERE key = :key", req.Key) result, err = tx.ExecContext(ctx, "DELETE FROM state WHERE key = :key", req.Key)
} else { } else {
result, err = tx.Exec("DELETE FROM state WHERE key = :key and etag = :etag", req.Key, *req.ETag) result, err = tx.ExecContext(ctx, "DELETE FROM state WHERE key = :key and etag = :etag", req.Key, *req.ETag)
} }
if err != nil { if err != nil {
if o.tx == nil { // not joining a preexisting transaction. if o.tx == nil { // not joining a preexisting transaction.
@ -303,7 +304,7 @@ func (o *oracleDatabaseAccess) deleteValue(req *state.DeleteRequest) error {
return nil return nil
} }
func (o *oracleDatabaseAccess) ExecuteMulti(sets []state.SetRequest, deletes []state.DeleteRequest) error { func (o *oracleDatabaseAccess) ExecuteMulti(ctx context.Context, sets []state.SetRequest, deletes []state.DeleteRequest) error {
o.logger.Debug("Executing multiple OracleDatabase operations, within a single transaction") o.logger.Debug("Executing multiple OracleDatabase operations, within a single transaction")
tx, err := o.db.Begin() tx, err := o.db.Begin()
if err != nil { if err != nil {
@ -313,7 +314,7 @@ func (o *oracleDatabaseAccess) ExecuteMulti(sets []state.SetRequest, deletes []s
if len(deletes) > 0 { if len(deletes) > 0 {
for _, d := range deletes { for _, d := range deletes {
da := d // Fix for gosec G601: Implicit memory aliasing in for looo. da := d // Fix for gosec G601: Implicit memory aliasing in for looo.
err = o.Delete(&da) err = o.Delete(ctx, &da)
if err != nil { if err != nil {
tx.Rollback() tx.Rollback()
return err return err
@ -323,7 +324,7 @@ func (o *oracleDatabaseAccess) ExecuteMulti(sets []state.SetRequest, deletes []s
if len(sets) > 0 { if len(sets) > 0 {
for _, s := range sets { for _, s := range sets {
sa := s // Fix for gosec G601: Implicit memory aliasing in for looo. sa := s // Fix for gosec G601: Implicit memory aliasing in for looo.
err = o.Set(&sa) err = o.Set(ctx, &sa)
if err != nil { if err != nil {
tx.Rollback() tx.Rollback()
return err return err

View File

@ -14,18 +14,20 @@ limitations under the License.
package postgresql package postgresql
import ( import (
"context"
"github.com/dapr/components-contrib/state" "github.com/dapr/components-contrib/state"
) )
// dbAccess is a private interface which enables unit testing of PostgreSQL. // dbAccess is a private interface which enables unit testing of PostgreSQL.
type dbAccess interface { type dbAccess interface {
Init(metadata state.Metadata) error Init(metadata state.Metadata) error
Set(req *state.SetRequest) error Set(ctx context.Context, req *state.SetRequest) error
BulkSet(req []state.SetRequest) error BulkSet(ctx context.Context, req []state.SetRequest) error
Get(req *state.GetRequest) (*state.GetResponse, error) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error)
Delete(req *state.DeleteRequest) error Delete(ctx context.Context, req *state.DeleteRequest) error
BulkDelete(req []state.DeleteRequest) error BulkDelete(ctx context.Context, req []state.DeleteRequest) error
ExecuteMulti(req *state.TransactionalStateRequest) error ExecuteMulti(ctx context.Context, req *state.TransactionalStateRequest) error
Query(req *state.QueryRequest) (*state.QueryResponse, error) Query(ctx context.Context, req *state.QueryRequest) (*state.QueryResponse, error)
Close() error // io.Closer Close() error // io.Closer
} }

View File

@ -14,6 +14,7 @@ limitations under the License.
package postgresql package postgresql
import ( import (
"context"
"database/sql" "database/sql"
"encoding/base64" "encoding/base64"
"encoding/json" "encoding/json"
@ -98,12 +99,12 @@ func (p *postgresDBAccess) Init(metadata state.Metadata) error {
} }
// Set makes an insert or update to the database. // Set makes an insert or update to the database.
func (p *postgresDBAccess) Set(req *state.SetRequest) error { func (p *postgresDBAccess) Set(ctx context.Context, req *state.SetRequest) error {
return state.SetWithOptions(p.setValue, req) return state.SetWithOptions(ctx, p.setValue, req)
} }
// setValue is an internal implementation of set to enable passing the logic to state.SetWithRetries as a func. // setValue is an internal implementation of set to enable passing the logic to state.SetWithRetries as a func.
func (p *postgresDBAccess) setValue(req *state.SetRequest) error { func (p *postgresDBAccess) setValue(ctx context.Context, req *state.SetRequest) error {
p.logger.Debug("Setting state value in PostgreSQL") p.logger.Debug("Setting state value in PostgreSQL")
err := state.CheckRequestOptions(req.Options) err := state.CheckRequestOptions(req.Options)
@ -134,7 +135,7 @@ func (p *postgresDBAccess) setValue(req *state.SetRequest) error {
// Sprintf is required for table name because sql.DB does not substitute parameters for table names. // Sprintf is required for table name because sql.DB does not substitute parameters for table names.
// Other parameters use sql.DB parameter substitution. // Other parameters use sql.DB parameter substitution.
if req.ETag == nil { if req.ETag == nil {
result, err = p.db.Exec(fmt.Sprintf( result, err = p.db.ExecContext(ctx, fmt.Sprintf(
`INSERT INTO %s (key, value, isbinary) VALUES ($1, $2, $3) `INSERT INTO %s (key, value, isbinary) VALUES ($1, $2, $3)
ON CONFLICT (key) DO UPDATE SET value = $2, isbinary = $3, updatedate = NOW();`, ON CONFLICT (key) DO UPDATE SET value = $2, isbinary = $3, updatedate = NOW();`,
tableName), req.Key, value, isBinary) tableName), req.Key, value, isBinary)
@ -148,7 +149,7 @@ func (p *postgresDBAccess) setValue(req *state.SetRequest) error {
etag := uint32(etag64) etag := uint32(etag64)
// When an etag is provided do an update - no insert // When an etag is provided do an update - no insert
result, err = p.db.Exec(fmt.Sprintf( result, err = p.db.ExecContext(ctx, fmt.Sprintf(
`UPDATE %s SET value = $1, isbinary = $2, updatedate = NOW() `UPDATE %s SET value = $1, isbinary = $2, updatedate = NOW()
WHERE key = $3 AND xmin = $4;`, WHERE key = $3 AND xmin = $4;`,
tableName), value, isBinary, req.Key, etag) tableName), value, isBinary, req.Key, etag)
@ -174,7 +175,7 @@ func (p *postgresDBAccess) setValue(req *state.SetRequest) error {
return nil return nil
} }
func (p *postgresDBAccess) BulkSet(req []state.SetRequest) error { func (p *postgresDBAccess) BulkSet(ctx context.Context, req []state.SetRequest) error {
p.logger.Debug("Executing BulkSet request") p.logger.Debug("Executing BulkSet request")
tx, err := p.db.Begin() tx, err := p.db.Begin()
if err != nil { if err != nil {
@ -184,7 +185,7 @@ func (p *postgresDBAccess) BulkSet(req []state.SetRequest) error {
if len(req) > 0 { if len(req) > 0 {
for _, s := range req { for _, s := range req {
sa := s // Fix for gosec G601: Implicit memory aliasing in for loop. sa := s // Fix for gosec G601: Implicit memory aliasing in for loop.
err = p.Set(&sa) err = p.Set(ctx, &sa)
if err != nil { if err != nil {
tx.Rollback() tx.Rollback()
@ -199,7 +200,7 @@ func (p *postgresDBAccess) BulkSet(req []state.SetRequest) error {
} }
// Get returns data from the database. If data does not exist for the key an empty state.GetResponse will be returned. // Get returns data from the database. If data does not exist for the key an empty state.GetResponse will be returned.
func (p *postgresDBAccess) Get(req *state.GetRequest) (*state.GetResponse, error) { func (p *postgresDBAccess) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) {
p.logger.Debug("Getting state value from PostgreSQL") p.logger.Debug("Getting state value from PostgreSQL")
if req.Key == "" { if req.Key == "" {
return nil, fmt.Errorf("missing key in get operation") return nil, fmt.Errorf("missing key in get operation")
@ -208,7 +209,7 @@ func (p *postgresDBAccess) Get(req *state.GetRequest) (*state.GetResponse, error
var value string var value string
var isBinary bool var isBinary bool
var etag int var etag int
err := p.db.QueryRow(fmt.Sprintf("SELECT value, isbinary, xmin as etag FROM %s WHERE key = $1", tableName), req.Key).Scan(&value, &isBinary, &etag) err := p.db.QueryRowContext(ctx, fmt.Sprintf("SELECT value, isbinary, xmin as etag FROM %s WHERE key = $1", tableName), req.Key).Scan(&value, &isBinary, &etag)
if err != nil { if err != nil {
// If no rows exist, return an empty response, otherwise return the error. // If no rows exist, return an empty response, otherwise return the error.
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
@ -245,12 +246,12 @@ func (p *postgresDBAccess) Get(req *state.GetRequest) (*state.GetResponse, error
} }
// Delete removes an item from the state store. // Delete removes an item from the state store.
func (p *postgresDBAccess) Delete(req *state.DeleteRequest) error { func (p *postgresDBAccess) Delete(ctx context.Context, req *state.DeleteRequest) error {
return state.DeleteWithOptions(p.deleteValue, req) return state.DeleteWithOptions(ctx, p.deleteValue, req)
} }
// deleteValue is an internal implementation of delete to enable passing the logic to state.DeleteWithRetries as a func. // deleteValue is an internal implementation of delete to enable passing the logic to state.DeleteWithRetries as a func.
func (p *postgresDBAccess) deleteValue(req *state.DeleteRequest) error { func (p *postgresDBAccess) deleteValue(ctx context.Context, req *state.DeleteRequest) error {
p.logger.Debug("Deleting state value from PostgreSQL") p.logger.Debug("Deleting state value from PostgreSQL")
if req.Key == "" { if req.Key == "" {
return fmt.Errorf("missing key in delete operation") return fmt.Errorf("missing key in delete operation")
@ -260,7 +261,7 @@ func (p *postgresDBAccess) deleteValue(req *state.DeleteRequest) error {
var err error var err error
if req.ETag == nil { if req.ETag == nil {
result, err = p.db.Exec("DELETE FROM state WHERE key = $1", req.Key) result, err = p.db.ExecContext(ctx, "DELETE FROM state WHERE key = $1", req.Key)
} else { } else {
// Convert req.ETag to uint32 for postgres XID compatibility // Convert req.ETag to uint32 for postgres XID compatibility
var etag64 uint64 var etag64 uint64
@ -270,7 +271,7 @@ func (p *postgresDBAccess) deleteValue(req *state.DeleteRequest) error {
} }
etag := uint32(etag64) etag := uint32(etag64)
result, err = p.db.Exec("DELETE FROM state WHERE key = $1 and xmin = $2", req.Key, etag) result, err = p.db.ExecContext(ctx, "DELETE FROM state WHERE key = $1 and xmin = $2", req.Key, etag)
} }
if err != nil { if err != nil {
@ -289,7 +290,7 @@ func (p *postgresDBAccess) deleteValue(req *state.DeleteRequest) error {
return nil return nil
} }
func (p *postgresDBAccess) BulkDelete(req []state.DeleteRequest) error { func (p *postgresDBAccess) BulkDelete(ctx context.Context, req []state.DeleteRequest) error {
p.logger.Debug("Executing BulkDelete request") p.logger.Debug("Executing BulkDelete request")
tx, err := p.db.Begin() tx, err := p.db.Begin()
if err != nil { if err != nil {
@ -299,7 +300,7 @@ func (p *postgresDBAccess) BulkDelete(req []state.DeleteRequest) error {
if len(req) > 0 { if len(req) > 0 {
for _, d := range req { for _, d := range req {
da := d // Fix for gosec G601: Implicit memory aliasing in for loop. da := d // Fix for gosec G601: Implicit memory aliasing in for loop.
err = p.Delete(&da) err = p.Delete(ctx, &da)
if err != nil { if err != nil {
tx.Rollback() tx.Rollback()
@ -313,7 +314,7 @@ func (p *postgresDBAccess) BulkDelete(req []state.DeleteRequest) error {
return err return err
} }
func (p *postgresDBAccess) ExecuteMulti(request *state.TransactionalStateRequest) error { func (p *postgresDBAccess) ExecuteMulti(ctx context.Context, request *state.TransactionalStateRequest) error {
p.logger.Debug("Executing PostgreSQL transaction") p.logger.Debug("Executing PostgreSQL transaction")
tx, err := p.db.Begin() tx, err := p.db.Begin()
@ -332,7 +333,7 @@ func (p *postgresDBAccess) ExecuteMulti(request *state.TransactionalStateRequest
return err return err
} }
err = p.Set(&setReq) err = p.Set(ctx, &setReq)
if err != nil { if err != nil {
tx.Rollback() tx.Rollback()
return err return err
@ -347,7 +348,7 @@ func (p *postgresDBAccess) ExecuteMulti(request *state.TransactionalStateRequest
return err return err
} }
err = p.Delete(&delReq) err = p.Delete(ctx, &delReq)
if err != nil { if err != nil {
tx.Rollback() tx.Rollback()
return err return err
@ -365,7 +366,7 @@ func (p *postgresDBAccess) ExecuteMulti(request *state.TransactionalStateRequest
} }
// Query executes a query against store. // Query executes a query against store.
func (p *postgresDBAccess) Query(req *state.QueryRequest) (*state.QueryResponse, error) { func (p *postgresDBAccess) Query(ctx context.Context, req *state.QueryRequest) (*state.QueryResponse, error) {
p.logger.Debug("Getting query value from PostgreSQL") p.logger.Debug("Getting query value from PostgreSQL")
q := &Query{ q := &Query{
query: "", query: "",
@ -375,7 +376,7 @@ func (p *postgresDBAccess) Query(req *state.QueryRequest) (*state.QueryResponse,
if err := qbuilder.BuildQuery(&req.Query); err != nil { if err := qbuilder.BuildQuery(&req.Query); err != nil {
return &state.QueryResponse{}, err return &state.QueryResponse{}, err
} }
data, token, err := q.execute(p.logger, p.db) data, token, err := q.execute(ctx, p.logger, p.db)
if err != nil { if err != nil {
return &state.QueryResponse{}, err return &state.QueryResponse{}, err
} }

View File

@ -15,6 +15,7 @@ limitations under the License.
package postgresql package postgresql
import ( import (
"context"
"database/sql" "database/sql"
"testing" "testing"
@ -110,7 +111,7 @@ func TestMultiWithNoRequests(t *testing.T) {
var operations []state.TransactionalStateOperation var operations []state.TransactionalStateOperation
// Act // Act
err := m.pgDba.ExecuteMulti(&state.TransactionalStateRequest{ err := m.pgDba.ExecuteMulti(context.TODO(), &state.TransactionalStateRequest{
Operations: operations, Operations: operations,
}) })
@ -134,7 +135,7 @@ func TestInvalidMultiInvalidAction(t *testing.T) {
}) })
// Act // Act
err := m.pgDba.ExecuteMulti(&state.TransactionalStateRequest{ err := m.pgDba.ExecuteMulti(context.TODO(), &state.TransactionalStateRequest{
Operations: operations, Operations: operations,
}) })
@ -159,7 +160,7 @@ func TestValidSetRequest(t *testing.T) {
}) })
// Act // Act
err := m.pgDba.ExecuteMulti(&state.TransactionalStateRequest{ err := m.pgDba.ExecuteMulti(context.TODO(), &state.TransactionalStateRequest{
Operations: operations, Operations: operations,
}) })
@ -183,7 +184,7 @@ func TestInvalidMultiSetRequest(t *testing.T) {
}) })
// Act // Act
err := m.pgDba.ExecuteMulti(&state.TransactionalStateRequest{ err := m.pgDba.ExecuteMulti(context.TODO(), &state.TransactionalStateRequest{
Operations: operations, Operations: operations,
}) })
@ -207,7 +208,7 @@ func TestInvalidMultiSetRequestNoKey(t *testing.T) {
}) })
// Act // Act
err := m.pgDba.ExecuteMulti(&state.TransactionalStateRequest{ err := m.pgDba.ExecuteMulti(context.TODO(), &state.TransactionalStateRequest{
Operations: operations, Operations: operations,
}) })
@ -232,7 +233,7 @@ func TestValidMultiDeleteRequest(t *testing.T) {
}) })
// Act // Act
err := m.pgDba.ExecuteMulti(&state.TransactionalStateRequest{ err := m.pgDba.ExecuteMulti(context.TODO(), &state.TransactionalStateRequest{
Operations: operations, Operations: operations,
}) })
@ -256,7 +257,7 @@ func TestInvalidMultiDeleteRequest(t *testing.T) {
}) })
// Act // Act
err := m.pgDba.ExecuteMulti(&state.TransactionalStateRequest{ err := m.pgDba.ExecuteMulti(context.TODO(), &state.TransactionalStateRequest{
Operations: operations, Operations: operations,
}) })
@ -280,7 +281,7 @@ func TestInvalidMultiDeleteRequestNoKey(t *testing.T) {
}) })
// Act // Act
err := m.pgDba.ExecuteMulti(&state.TransactionalStateRequest{ err := m.pgDba.ExecuteMulti(context.TODO(), &state.TransactionalStateRequest{
Operations: operations, Operations: operations,
}) })
@ -312,7 +313,7 @@ func TestMultiOperationOrder(t *testing.T) {
) )
// Act // Act
err := m.pgDba.ExecuteMulti(&state.TransactionalStateRequest{ err := m.pgDba.ExecuteMulti(context.TODO(), &state.TransactionalStateRequest{
Operations: operations, Operations: operations,
}) })
@ -335,7 +336,7 @@ func TestInvalidBulkSetNoKey(t *testing.T) {
}) })
// Act // Act
err := m.pgDba.BulkSet(sets) err := m.pgDba.BulkSet(context.TODO(), sets)
// Assert // Assert
assert.NotNil(t, err) assert.NotNil(t, err)
@ -357,7 +358,7 @@ func TestInvalidBulkSetEmptyValue(t *testing.T) {
}) })
// Act // Act
err := m.pgDba.BulkSet(sets) err := m.pgDba.BulkSet(context.TODO(), sets)
// Assert // Assert
assert.NotNil(t, err) assert.NotNil(t, err)
@ -380,7 +381,7 @@ func TestValidBulkSet(t *testing.T) {
}) })
// Act // Act
err := m.pgDba.BulkSet(sets) err := m.pgDba.BulkSet(context.TODO(), sets)
// Assert // Assert
assert.Nil(t, err) assert.Nil(t, err)
@ -401,7 +402,7 @@ func TestInvalidBulkDeleteNoKey(t *testing.T) {
}) })
// Act // Act
err := m.pgDba.BulkDelete(deletes) err := m.pgDba.BulkDelete(context.TODO(), deletes)
// Assert // Assert
assert.NotNil(t, err) assert.NotNil(t, err)
@ -423,7 +424,7 @@ func TestValidBulkDelete(t *testing.T) {
}) })
// Act // Act
err := m.pgDba.BulkDelete(deletes) err := m.pgDba.BulkDelete(context.TODO(), deletes)
// Assert // Assert
assert.Nil(t, err) assert.Nil(t, err)

View File

@ -14,6 +14,8 @@ limitations under the License.
package postgresql package postgresql
import ( import (
"context"
"github.com/dapr/components-contrib/state" "github.com/dapr/components-contrib/state"
"github.com/dapr/kit/logger" "github.com/dapr/kit/logger"
) )
@ -53,44 +55,44 @@ func (p *PostgreSQL) Features() []state.Feature {
} }
// Delete removes an entity from the store. // Delete removes an entity from the store.
func (p *PostgreSQL) Delete(req *state.DeleteRequest) error { func (p *PostgreSQL) Delete(ctx context.Context, req *state.DeleteRequest) error {
return p.dbaccess.Delete(req) return p.dbaccess.Delete(ctx, req)
} }
// BulkDelete removes multiple entries from the store. // BulkDelete removes multiple entries from the store.
func (p *PostgreSQL) BulkDelete(req []state.DeleteRequest) error { func (p *PostgreSQL) BulkDelete(ctx context.Context, req []state.DeleteRequest) error {
return p.dbaccess.BulkDelete(req) return p.dbaccess.BulkDelete(ctx, req)
} }
// Get returns an entity from store. // Get returns an entity from store.
func (p *PostgreSQL) Get(req *state.GetRequest) (*state.GetResponse, error) { func (p *PostgreSQL) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) {
return p.dbaccess.Get(req) return p.dbaccess.Get(ctx, req)
} }
// BulkGet performs a bulks get operations. // BulkGet performs a bulks get operations.
func (p *PostgreSQL) BulkGet(req []state.GetRequest) (bool, []state.BulkGetResponse, error) { func (p *PostgreSQL) BulkGet(ctx context.Context, req []state.GetRequest) (bool, []state.BulkGetResponse, error) {
// TODO: replace with ExecuteMulti for performance // TODO: replace with ExecuteMulti for performance
return false, nil, nil return false, nil, nil
} }
// Set adds/updates an entity on store. // Set adds/updates an entity on store.
func (p *PostgreSQL) Set(req *state.SetRequest) error { func (p *PostgreSQL) Set(ctx context.Context, req *state.SetRequest) error {
return p.dbaccess.Set(req) return p.dbaccess.Set(ctx, req)
} }
// BulkSet adds/updates multiple entities on store. // BulkSet adds/updates multiple entities on store.
func (p *PostgreSQL) BulkSet(req []state.SetRequest) error { func (p *PostgreSQL) BulkSet(ctx context.Context, req []state.SetRequest) error {
return p.dbaccess.BulkSet(req) return p.dbaccess.BulkSet(ctx, req)
} }
// Multi handles multiple transactions. Implements TransactionalStore. // Multi handles multiple transactions. Implements TransactionalStore.
func (p *PostgreSQL) Multi(request *state.TransactionalStateRequest) error { func (p *PostgreSQL) Multi(ctx context.Context, request *state.TransactionalStateRequest) error {
return p.dbaccess.ExecuteMulti(request) return p.dbaccess.ExecuteMulti(ctx, request)
} }
// Query executes a query against store. // Query executes a query against store.
func (p *PostgreSQL) Query(req *state.QueryRequest) (*state.QueryResponse, error) { func (p *PostgreSQL) Query(ctx context.Context, req *state.QueryRequest) (*state.QueryResponse, error) {
return p.dbaccess.Query(req) return p.dbaccess.Query(ctx, req)
} }
// Close implements io.Closer. // Close implements io.Closer.

View File

@ -15,6 +15,7 @@ limitations under the License.
package postgresql package postgresql
import ( import (
"context"
"database/sql" "database/sql"
"encoding/json" "encoding/json"
"fmt" "fmt"
@ -192,7 +193,7 @@ func deleteItemThatDoesNotExist(t *testing.T, pgs *PostgreSQL) {
deleteReq := &state.DeleteRequest{ deleteReq := &state.DeleteRequest{
Key: randomKey(), Key: randomKey(),
} }
err := pgs.Delete(deleteReq) err := pgs.Delete(context.TODO(), deleteReq)
assert.Nil(t, err) assert.Nil(t, err)
} }
@ -211,7 +212,7 @@ func multiWithSetOnly(t *testing.T, pgs *PostgreSQL) {
}) })
} }
err := pgs.Multi(&state.TransactionalStateRequest{ err := pgs.Multi(context.TODO(), &state.TransactionalStateRequest{
Operations: operations, Operations: operations,
}) })
assert.Nil(t, err) assert.Nil(t, err)
@ -241,7 +242,7 @@ func multiWithDeleteOnly(t *testing.T, pgs *PostgreSQL) {
}) })
} }
err := pgs.Multi(&state.TransactionalStateRequest{ err := pgs.Multi(context.TODO(), &state.TransactionalStateRequest{
Operations: operations, Operations: operations,
}) })
assert.Nil(t, err) assert.Nil(t, err)
@ -284,7 +285,7 @@ func multiWithDeleteAndSet(t *testing.T, pgs *PostgreSQL) {
}) })
} }
err := pgs.Multi(&state.TransactionalStateRequest{ err := pgs.Multi(context.TODO(), &state.TransactionalStateRequest{
Operations: operations, Operations: operations,
}) })
assert.Nil(t, err) assert.Nil(t, err)
@ -311,7 +312,7 @@ func deleteWithInvalidEtagFails(t *testing.T, pgs *PostgreSQL) {
Key: key, Key: key,
ETag: &etag, ETag: &etag,
} }
err := pgs.Delete(deleteReq) err := pgs.Delete(context.TODO(), deleteReq)
assert.NotNil(t, err) assert.NotNil(t, err)
} }
@ -319,7 +320,7 @@ func deleteWithNoKeyFails(t *testing.T, pgs *PostgreSQL) {
deleteReq := &state.DeleteRequest{ deleteReq := &state.DeleteRequest{
Key: "", Key: "",
} }
err := pgs.Delete(deleteReq) err := pgs.Delete(context.TODO(), deleteReq)
assert.NotNil(t, err) assert.NotNil(t, err)
} }
@ -334,7 +335,7 @@ func newItemWithEtagFails(t *testing.T, pgs *PostgreSQL) {
Value: value, Value: value,
} }
err := pgs.Set(setReq) err := pgs.Set(context.TODO(), setReq)
assert.NotNil(t, err) assert.NotNil(t, err)
} }
@ -360,7 +361,7 @@ func updateWithOldEtagFails(t *testing.T, pgs *PostgreSQL) {
ETag: originalEtag, ETag: originalEtag,
Value: newValue, Value: newValue,
} }
err := pgs.Set(setReq) err := pgs.Set(context.TODO(), setReq)
assert.NotNil(t, err) assert.NotNil(t, err)
} }
@ -402,7 +403,7 @@ func getItemWithNoKey(t *testing.T, pgs *PostgreSQL) {
Key: "", Key: "",
} }
response, getErr := pgs.Get(getReq) response, getErr := pgs.Get(context.TODO(), getReq)
assert.NotNil(t, getErr) assert.NotNil(t, getErr)
assert.Nil(t, response) assert.Nil(t, response)
} }
@ -433,7 +434,7 @@ func setItemWithNoKey(t *testing.T, pgs *PostgreSQL) {
Key: "", Key: "",
} }
err := pgs.Set(setReq) err := pgs.Set(context.TODO(), setReq)
assert.NotNil(t, err) assert.NotNil(t, err)
} }
@ -450,7 +451,7 @@ func testBulkSetAndBulkDelete(t *testing.T, pgs *PostgreSQL) {
}, },
} }
err := pgs.BulkSet(setReq) err := pgs.BulkSet(context.TODO(), setReq)
assert.Nil(t, err) assert.Nil(t, err)
assert.True(t, storeItemExists(t, setReq[0].Key)) assert.True(t, storeItemExists(t, setReq[0].Key))
assert.True(t, storeItemExists(t, setReq[1].Key)) assert.True(t, storeItemExists(t, setReq[1].Key))
@ -464,7 +465,7 @@ func testBulkSetAndBulkDelete(t *testing.T, pgs *PostgreSQL) {
}, },
} }
err = pgs.BulkDelete(deleteReq) err = pgs.BulkDelete(context.TODO(), deleteReq)
assert.Nil(t, err) assert.Nil(t, err)
assert.False(t, storeItemExists(t, setReq[0].Key)) assert.False(t, storeItemExists(t, setReq[0].Key))
assert.False(t, storeItemExists(t, setReq[1].Key)) assert.False(t, storeItemExists(t, setReq[1].Key))
@ -521,7 +522,7 @@ func setItem(t *testing.T, pgs *PostgreSQL, key string, value interface{}, etag
Value: value, Value: value,
} }
err := pgs.Set(setReq) err := pgs.Set(context.TODO(), setReq)
assert.Nil(t, err) assert.Nil(t, err)
itemExists := storeItemExists(t, key) itemExists := storeItemExists(t, key)
assert.True(t, itemExists) assert.True(t, itemExists)
@ -533,7 +534,7 @@ func getItem(t *testing.T, pgs *PostgreSQL, key string) (*state.GetResponse, *fa
Options: state.GetStateOption{}, Options: state.GetStateOption{},
} }
response, getErr := pgs.Get(getReq) response, getErr := pgs.Get(context.TODO(), getReq)
assert.Nil(t, getErr) assert.Nil(t, getErr)
assert.NotNil(t, response) assert.NotNil(t, response)
outputObject := &fakeItem{} outputObject := &fakeItem{}
@ -549,7 +550,7 @@ func deleteItem(t *testing.T, pgs *PostgreSQL, key string, etag *string) {
Options: state.DeleteStateOption{}, Options: state.DeleteStateOption{},
} }
deleteErr := pgs.Delete(deleteReq) deleteErr := pgs.Delete(context.TODO(), deleteReq)
assert.Nil(t, deleteErr) assert.Nil(t, deleteErr)
assert.False(t, storeItemExists(t, key)) assert.False(t, storeItemExists(t, key))
} }

View File

@ -15,6 +15,7 @@ limitations under the License.
package postgresql package postgresql
import ( import (
"context"
"database/sql" "database/sql"
"fmt" "fmt"
"strconv" "strconv"
@ -139,8 +140,8 @@ func (q *Query) Finalize(filters string, qq *query.Query) error {
return nil return nil
} }
func (q *Query) execute(logger logger.Logger, db *sql.DB) ([]state.QueryItem, string, error) { func (q *Query) execute(ctx context.Context, logger logger.Logger, db *sql.DB) ([]state.QueryItem, string, error) {
rows, err := db.Query(q.query, q.params...) rows, err := db.QueryContext(ctx, q.query, q.params...)
if err != nil { if err != nil {
return nil, "", err return nil, "", err
} }

View File

@ -15,6 +15,7 @@ limitations under the License.
package postgresql package postgresql
import ( import (
"context"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@ -43,37 +44,37 @@ func (m *fakeDBaccess) Init(metadata state.Metadata) error {
return nil return nil
} }
func (m *fakeDBaccess) Set(req *state.SetRequest) error { func (m *fakeDBaccess) Set(ctx context.Context, req *state.SetRequest) error {
m.setExecuted = true m.setExecuted = true
return nil return nil
} }
func (m *fakeDBaccess) Get(req *state.GetRequest) (*state.GetResponse, error) { func (m *fakeDBaccess) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) {
m.getExecuted = true m.getExecuted = true
return nil, nil return nil, nil
} }
func (m *fakeDBaccess) Delete(req *state.DeleteRequest) error { func (m *fakeDBaccess) Delete(ctx context.Context, req *state.DeleteRequest) error {
m.deleteExecuted = true m.deleteExecuted = true
return nil return nil
} }
func (m *fakeDBaccess) BulkSet(req []state.SetRequest) error { func (m *fakeDBaccess) BulkSet(ctx context.Context, req []state.SetRequest) error {
return nil return nil
} }
func (m *fakeDBaccess) BulkDelete(req []state.DeleteRequest) error { func (m *fakeDBaccess) BulkDelete(ctx context.Context, req []state.DeleteRequest) error {
return nil return nil
} }
func (m *fakeDBaccess) ExecuteMulti(req *state.TransactionalStateRequest) error { func (m *fakeDBaccess) ExecuteMulti(ctx context.Context, req *state.TransactionalStateRequest) error {
return nil return nil
} }
func (m *fakeDBaccess) Query(req *state.QueryRequest) (*state.QueryResponse, error) { func (m *fakeDBaccess) Query(ctx context.Context, req *state.QueryRequest) (*state.QueryResponse, error) {
return nil, nil return nil, nil
} }

View File

@ -195,7 +195,7 @@ func (r *StateStore) parseConnectedSlaves(res string) int {
return 0 return 0
} }
func (r *StateStore) deleteValue(req *state.DeleteRequest) error { func (r *StateStore) deleteValue(ctx context.Context, req *state.DeleteRequest) error {
if req.ETag == nil { if req.ETag == nil {
etag := "0" etag := "0"
req.ETag = &etag req.ETag = &etag
@ -207,7 +207,7 @@ func (r *StateStore) deleteValue(req *state.DeleteRequest) error {
} else { } else {
delQuery = delDefaultQuery delQuery = delDefaultQuery
} }
_, err := r.client.Do(r.ctx, "EVAL", delQuery, 1, req.Key, *req.ETag).Result() _, err := r.client.Do(ctx, "EVAL", delQuery, 1, req.Key, *req.ETag).Result()
if err != nil { if err != nil {
return state.NewETagError(state.ETagMismatch, err) return state.NewETagError(state.ETagMismatch, err)
} }
@ -216,17 +216,17 @@ func (r *StateStore) deleteValue(req *state.DeleteRequest) error {
} }
// Delete performs a delete operation. // Delete performs a delete operation.
func (r *StateStore) Delete(req *state.DeleteRequest) error { func (r *StateStore) Delete(ctx context.Context, req *state.DeleteRequest) error {
err := state.CheckRequestOptions(req.Options) err := state.CheckRequestOptions(req.Options)
if err != nil { if err != nil {
return err return err
} }
return state.DeleteWithOptions(r.deleteValue, req) return state.DeleteWithOptions(ctx, r.deleteValue, req)
} }
func (r *StateStore) directGet(req *state.GetRequest) (*state.GetResponse, error) { func (r *StateStore) directGet(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) {
res, err := r.client.Do(r.ctx, "GET", req.Key).Result() res, err := r.client.Do(ctx, "GET", req.Key).Result()
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -242,10 +242,10 @@ func (r *StateStore) directGet(req *state.GetRequest) (*state.GetResponse, error
}, nil }, nil
} }
func (r *StateStore) getDefault(req *state.GetRequest) (*state.GetResponse, error) { func (r *StateStore) getDefault(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) {
res, err := r.client.Do(r.ctx, "HGETALL", req.Key).Result() // Prefer values with ETags res, err := r.client.Do(ctx, "HGETALL", req.Key).Result() // Prefer values with ETags
if err != nil { if err != nil {
return r.directGet(req) // Falls back to original get for backward compats. return r.directGet(ctx, req) // Falls back to original get for backward compats.
} }
if res == nil { if res == nil {
return &state.GetResponse{}, nil return &state.GetResponse{}, nil
@ -304,12 +304,12 @@ func (r *StateStore) getJSON(req *state.GetRequest) (*state.GetResponse, error)
} }
// Get retrieves state from redis with a key. // Get retrieves state from redis with a key.
func (r *StateStore) Get(req *state.GetRequest) (*state.GetResponse, error) { func (r *StateStore) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) {
if contentType, ok := req.Metadata[daprmetadata.ContentType]; ok && contentType == contenttype.JSONContentType { if contentType, ok := req.Metadata[daprmetadata.ContentType]; ok && contentType == contenttype.JSONContentType {
return r.getJSON(req) return r.getJSON(req)
} }
return r.getDefault(req) return r.getDefault(ctx, req)
} }
type jsonEntry struct { type jsonEntry struct {
@ -317,7 +317,7 @@ type jsonEntry struct {
Version *int `json:"version,omitempty"` Version *int `json:"version,omitempty"`
} }
func (r *StateStore) setValue(req *state.SetRequest) error { func (r *StateStore) setValue(ctx context.Context, req *state.SetRequest) error {
err := state.CheckRequestOptions(req.Options) err := state.CheckRequestOptions(req.Options)
if err != nil { if err != nil {
return err return err
@ -350,7 +350,7 @@ func (r *StateStore) setValue(req *state.SetRequest) error {
bt, _ = utils.Marshal(req.Value, r.json.Marshal) bt, _ = utils.Marshal(req.Value, r.json.Marshal)
} }
err = r.client.Do(r.ctx, "EVAL", setQuery, 1, req.Key, ver, bt, firstWrite).Err() err = r.client.Do(ctx, "EVAL", setQuery, 1, req.Key, ver, bt, firstWrite).Err()
if err != nil { if err != nil {
if req.ETag != nil { if req.ETag != nil {
return state.NewETagError(state.ETagMismatch, err) return state.NewETagError(state.ETagMismatch, err)
@ -360,21 +360,21 @@ func (r *StateStore) setValue(req *state.SetRequest) error {
} }
if ttl != nil && *ttl > 0 { if ttl != nil && *ttl > 0 {
_, err = r.client.Do(r.ctx, "EXPIRE", req.Key, *ttl).Result() _, err = r.client.Do(ctx, "EXPIRE", req.Key, *ttl).Result()
if err != nil { if err != nil {
return fmt.Errorf("failed to set key %s ttl: %s", req.Key, err) return fmt.Errorf("failed to set key %s ttl: %s", req.Key, err)
} }
} }
if ttl != nil && *ttl <= 0 { if ttl != nil && *ttl <= 0 {
_, err = r.client.Do(r.ctx, "PERSIST", req.Key).Result() _, err = r.client.Do(ctx, "PERSIST", req.Key).Result()
if err != nil { if err != nil {
return fmt.Errorf("failed to persist key %s: %s", req.Key, err) return fmt.Errorf("failed to persist key %s: %s", req.Key, err)
} }
} }
if req.Options.Consistency == state.Strong && r.replicas > 0 { if req.Options.Consistency == state.Strong && r.replicas > 0 {
_, err = r.client.Do(r.ctx, "WAIT", r.replicas, 1000).Result() _, err = r.client.Do(ctx, "WAIT", r.replicas, 1000).Result()
if err != nil { if err != nil {
return fmt.Errorf("redis waiting for %v replicas to acknowledge write, err: %s", r.replicas, err.Error()) return fmt.Errorf("redis waiting for %v replicas to acknowledge write, err: %s", r.replicas, err.Error())
} }
@ -384,12 +384,12 @@ func (r *StateStore) setValue(req *state.SetRequest) error {
} }
// Set saves state into redis. // Set saves state into redis.
func (r *StateStore) Set(req *state.SetRequest) error { func (r *StateStore) Set(ctx context.Context, req *state.SetRequest) error {
return state.SetWithOptions(r.setValue, req) return state.SetWithOptions(ctx, r.setValue, req)
} }
// Multi performs a transactional operation. succeeds only if all operations succeed, and fails if one or more operations fail. // Multi performs a transactional operation. succeeds only if all operations succeed, and fails if one or more operations fail.
func (r *StateStore) Multi(request *state.TransactionalStateRequest) error { func (r *StateStore) Multi(ctx context.Context, request *state.TransactionalStateRequest) error {
var setQuery, delQuery string var setQuery, delQuery string
var isJSON bool var isJSON bool
if contentType, ok := request.Metadata[daprmetadata.ContentType]; ok && contentType == contenttype.JSONContentType { if contentType, ok := request.Metadata[daprmetadata.ContentType]; ok && contentType == contenttype.JSONContentType {
@ -423,12 +423,12 @@ func (r *StateStore) Multi(request *state.TransactionalStateRequest) error {
} else { } else {
bt, _ = utils.Marshal(req.Value, r.json.Marshal) bt, _ = utils.Marshal(req.Value, r.json.Marshal)
} }
pipe.Do(r.ctx, "EVAL", setQuery, 1, req.Key, ver, bt) pipe.Do(ctx, "EVAL", setQuery, 1, req.Key, ver, bt)
if ttl != nil && *ttl > 0 { if ttl != nil && *ttl > 0 {
pipe.Do(r.ctx, "EXPIRE", req.Key, *ttl) pipe.Do(ctx, "EXPIRE", req.Key, *ttl)
} }
if ttl != nil && *ttl <= 0 { if ttl != nil && *ttl <= 0 {
pipe.Do(r.ctx, "PERSIST", req.Key) pipe.Do(ctx, "PERSIST", req.Key)
} }
} else if o.Operation == state.Delete { } else if o.Operation == state.Delete {
req := o.Request.(state.DeleteRequest) req := o.Request.(state.DeleteRequest)
@ -436,11 +436,11 @@ func (r *StateStore) Multi(request *state.TransactionalStateRequest) error {
etag := "0" etag := "0"
req.ETag = &etag req.ETag = &etag
} }
pipe.Do(r.ctx, "EVAL", delQuery, 1, req.Key, *req.ETag) pipe.Do(ctx, "EVAL", delQuery, 1, req.Key, *req.ETag)
} }
} }
_, err := pipe.Exec(r.ctx) _, err := pipe.Exec(ctx)
return err return err
} }
@ -514,7 +514,7 @@ func (r *StateStore) parseTTL(req *state.SetRequest) (*int, error) {
} }
// Query executes a query against store. // Query executes a query against store.
func (r *StateStore) Query(req *state.QueryRequest) (*state.QueryResponse, error) { func (r *StateStore) Query(ctx context.Context, req *state.QueryRequest) (*state.QueryResponse, error) {
indexName, ok := daprmetadata.TryGetQueryIndexName(req.Metadata) indexName, ok := daprmetadata.TryGetQueryIndexName(req.Metadata)
if !ok { if !ok {
return nil, fmt.Errorf("query index not found") return nil, fmt.Errorf("query index not found")
@ -529,7 +529,7 @@ func (r *StateStore) Query(req *state.QueryRequest) (*state.QueryResponse, error
if err := qbuilder.BuildQuery(&req.Query); err != nil { if err := qbuilder.BuildQuery(&req.Query); err != nil {
return &state.QueryResponse{}, err return &state.QueryResponse{}, err
} }
data, token, err := q.execute(r.ctx, r.client) data, token, err := q.execute(ctx, r.client)
if err != nil { if err != nil {
return &state.QueryResponse{}, err return &state.QueryResponse{}, err
} }

View File

@ -206,7 +206,7 @@ func TestTransactionalUpsert(t *testing.T) {
} }
ss.ctx, ss.cancel = context.WithCancel(context.Background()) ss.ctx, ss.cancel = context.WithCancel(context.Background())
err := ss.Multi(&state.TransactionalStateRequest{ err := ss.Multi(context.TODO(), &state.TransactionalStateRequest{
Operations: []state.TransactionalStateOperation{ Operations: []state.TransactionalStateOperation{
{ {
Operation: state.Upsert, Operation: state.Upsert,
@ -273,13 +273,13 @@ func TestTransactionalDelete(t *testing.T) {
ss.ctx, ss.cancel = context.WithCancel(context.Background()) ss.ctx, ss.cancel = context.WithCancel(context.Background())
// Insert a record first. // Insert a record first.
ss.Set(&state.SetRequest{ ss.Set(context.TODO(), &state.SetRequest{
Key: "weapon", Key: "weapon",
Value: "deathstar", Value: "deathstar",
}) })
etag := "1" etag := "1"
err := ss.Multi(&state.TransactionalStateRequest{ err := ss.Multi(context.TODO(), &state.TransactionalStateRequest{
Operations: []state.TransactionalStateOperation{{ Operations: []state.TransactionalStateOperation{{
Operation: state.Delete, Operation: state.Delete,
Request: state.DeleteRequest{ Request: state.DeleteRequest{
@ -331,7 +331,7 @@ func TestRequestsWithGlobalTTL(t *testing.T) {
ss.ctx, ss.cancel = context.WithCancel(context.Background()) ss.ctx, ss.cancel = context.WithCancel(context.Background())
t.Run("TTL: Only global specified", func(t *testing.T) { t.Run("TTL: Only global specified", func(t *testing.T) {
ss.Set(&state.SetRequest{ ss.Set(context.TODO(), &state.SetRequest{
Key: "weapon100", Key: "weapon100",
Value: "deathstar100", Value: "deathstar100",
}) })
@ -342,7 +342,7 @@ func TestRequestsWithGlobalTTL(t *testing.T) {
t.Run("TTL: Global and Request specified", func(t *testing.T) { t.Run("TTL: Global and Request specified", func(t *testing.T) {
requestTTL := 200 requestTTL := 200
ss.Set(&state.SetRequest{ ss.Set(context.TODO(), &state.SetRequest{
Key: "weapon100", Key: "weapon100",
Value: "deathstar100", Value: "deathstar100",
Metadata: map[string]string{ Metadata: map[string]string{
@ -355,7 +355,7 @@ func TestRequestsWithGlobalTTL(t *testing.T) {
}) })
t.Run("TTL: Global and Request specified", func(t *testing.T) { t.Run("TTL: Global and Request specified", func(t *testing.T) {
err := ss.Multi(&state.TransactionalStateRequest{ err := ss.Multi(context.TODO(), &state.TransactionalStateRequest{
Operations: []state.TransactionalStateOperation{ Operations: []state.TransactionalStateOperation{
{ {
Operation: state.Upsert, Operation: state.Upsert,
@ -424,7 +424,7 @@ func TestSetRequestWithTTL(t *testing.T) {
t.Run("TTL specified", func(t *testing.T) { t.Run("TTL specified", func(t *testing.T) {
ttlInSeconds := 100 ttlInSeconds := 100
ss.Set(&state.SetRequest{ ss.Set(context.TODO(), &state.SetRequest{
Key: "weapon100", Key: "weapon100",
Value: "deathstar100", Value: "deathstar100",
Metadata: map[string]string{ Metadata: map[string]string{
@ -438,7 +438,7 @@ func TestSetRequestWithTTL(t *testing.T) {
}) })
t.Run("TTL not specified", func(t *testing.T) { t.Run("TTL not specified", func(t *testing.T) {
ss.Set(&state.SetRequest{ ss.Set(context.TODO(), &state.SetRequest{
Key: "weapon200", Key: "weapon200",
Value: "deathstar200", Value: "deathstar200",
}) })
@ -449,7 +449,7 @@ func TestSetRequestWithTTL(t *testing.T) {
}) })
t.Run("TTL Changed for Existing Key", func(t *testing.T) { t.Run("TTL Changed for Existing Key", func(t *testing.T) {
ss.Set(&state.SetRequest{ ss.Set(context.TODO(), &state.SetRequest{
Key: "weapon300", Key: "weapon300",
Value: "deathstar300", Value: "deathstar300",
}) })
@ -458,7 +458,7 @@ func TestSetRequestWithTTL(t *testing.T) {
// make the key no longer persistent // make the key no longer persistent
ttlInSeconds := 123 ttlInSeconds := 123
ss.Set(&state.SetRequest{ ss.Set(context.TODO(), &state.SetRequest{
Key: "weapon300", Key: "weapon300",
Value: "deathstar300", Value: "deathstar300",
Metadata: map[string]string{ Metadata: map[string]string{
@ -469,7 +469,7 @@ func TestSetRequestWithTTL(t *testing.T) {
assert.Equal(t, time.Duration(ttlInSeconds)*time.Second, ttl) assert.Equal(t, time.Duration(ttlInSeconds)*time.Second, ttl)
// make the key persistent again // make the key persistent again
ss.Set(&state.SetRequest{ ss.Set(context.TODO(), &state.SetRequest{
Key: "weapon300", Key: "weapon300",
Value: "deathstar301", Value: "deathstar301",
Metadata: map[string]string{ Metadata: map[string]string{
@ -493,12 +493,12 @@ func TestTransactionalDeleteNoEtag(t *testing.T) {
ss.ctx, ss.cancel = context.WithCancel(context.Background()) ss.ctx, ss.cancel = context.WithCancel(context.Background())
// Insert a record first. // Insert a record first.
ss.Set(&state.SetRequest{ ss.Set(context.TODO(), &state.SetRequest{
Key: "weapon100", Key: "weapon100",
Value: "deathstar100", Value: "deathstar100",
}) })
err := ss.Multi(&state.TransactionalStateRequest{ err := ss.Multi(context.TODO(), &state.TransactionalStateRequest{
Operations: []state.TransactionalStateOperation{{ Operations: []state.TransactionalStateOperation{{
Operation: state.Delete, Operation: state.Delete,
Request: state.DeleteRequest{ Request: state.DeleteRequest{

View File

@ -14,6 +14,7 @@ limitations under the License.
package state package state
import ( import (
"context"
"fmt" "fmt"
) )
@ -68,11 +69,11 @@ func validateConsistencyOption(c string) error {
} }
// SetWithOptions handles SetRequest with request options. // SetWithOptions handles SetRequest with request options.
func SetWithOptions(method func(req *SetRequest) error, req *SetRequest) error { func SetWithOptions(ctx context.Context, method func(ctx context.Context, req *SetRequest) error, req *SetRequest) error {
return method(req) return method(ctx, req)
} }
// DeleteWithOptions handles DeleteRequest with options. // DeleteWithOptions handles DeleteRequest with options.
func DeleteWithOptions(method func(req *DeleteRequest) error, req *DeleteRequest) error { func DeleteWithOptions(ctx context.Context, method func(ctx context.Context, req *DeleteRequest) error, req *DeleteRequest) error {
return method(req) return method(ctx, req)
} }

View File

@ -14,6 +14,7 @@ limitations under the License.
package state package state
import ( import (
"context"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@ -23,7 +24,7 @@ import (
func TestSetRequestWithOptions(t *testing.T) { func TestSetRequestWithOptions(t *testing.T) {
t.Run("set with default options", func(t *testing.T) { t.Run("set with default options", func(t *testing.T) {
counter := 0 counter := 0
SetWithOptions(func(req *SetRequest) error { SetWithOptions(context.TODO(), func(ctx context.Context, req *SetRequest) error {
counter++ counter++
return nil return nil
@ -33,7 +34,7 @@ func TestSetRequestWithOptions(t *testing.T) {
t.Run("set with no explicit options", func(t *testing.T) { t.Run("set with no explicit options", func(t *testing.T) {
counter := 0 counter := 0
SetWithOptions(func(req *SetRequest) error { SetWithOptions(context.TODO(), func(ctx context.Context, req *SetRequest) error {
counter++ counter++
return nil return nil

View File

@ -14,6 +14,7 @@ limitations under the License.
package rethinkdb package rethinkdb
import ( import (
"context"
"encoding/json" "encoding/json"
"io" "io"
"strconv" "strconv"
@ -147,7 +148,7 @@ func tableExists(arr []string, table string) bool {
} }
// Get retrieves a RethinkDB KV item. // Get retrieves a RethinkDB KV item.
func (s *RethinkDB) Get(req *state.GetRequest) (*state.GetResponse, error) { func (s *RethinkDB) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) {
if req == nil || req.Key == "" { if req == nil || req.Key == "" {
return nil, errors.New("invalid state request, missing key") return nil, errors.New("invalid state request, missing key")
} }
@ -187,22 +188,22 @@ func (s *RethinkDB) Get(req *state.GetRequest) (*state.GetResponse, error) {
} }
// BulkGet performs a bulks get operations. // BulkGet performs a bulks get operations.
func (s *RethinkDB) BulkGet(req []state.GetRequest) (bool, []state.BulkGetResponse, error) { func (s *RethinkDB) BulkGet(ctx context.Context, req []state.GetRequest) (bool, []state.BulkGetResponse, error) {
// TODO: replace with bulk get for performance // TODO: replace with bulk get for performance
return false, nil, nil return false, nil, nil
} }
// Set saves a state KV item. // Set saves a state KV item.
func (s *RethinkDB) Set(req *state.SetRequest) error { func (s *RethinkDB) Set(ctx context.Context, req *state.SetRequest) error {
if req == nil || req.Key == "" || req.Value == nil { if req == nil || req.Key == "" || req.Value == nil {
return errors.New("invalid state request, key and value required") return errors.New("invalid state request, key and value required")
} }
return s.BulkSet([]state.SetRequest{*req}) return s.BulkSet(ctx, []state.SetRequest{*req})
} }
// BulkSet performs a bulk save operation. // BulkSet performs a bulk save operation.
func (s *RethinkDB) BulkSet(req []state.SetRequest) error { func (s *RethinkDB) BulkSet(ctx context.Context, req []state.SetRequest) error {
docs := make([]*stateRecord, len(req)) docs := make([]*stateRecord, len(req))
for i, v := range req { for i, v := range req {
var etag string var etag string
@ -257,16 +258,16 @@ func (s *RethinkDB) archive(changes []r.ChangeResponse) error {
} }
// Delete performes a RethinkDB KV delete operation. // Delete performes a RethinkDB KV delete operation.
func (s *RethinkDB) Delete(req *state.DeleteRequest) error { func (s *RethinkDB) Delete(ctx context.Context, req *state.DeleteRequest) error {
if req == nil || req.Key == "" { if req == nil || req.Key == "" {
return errors.New("invalid request, missing key") return errors.New("invalid request, missing key")
} }
return s.BulkDelete([]state.DeleteRequest{*req}) return s.BulkDelete(ctx, []state.DeleteRequest{*req})
} }
// BulkDelete performs a bulk delete operation. // BulkDelete performs a bulk delete operation.
func (s *RethinkDB) BulkDelete(req []state.DeleteRequest) error { func (s *RethinkDB) BulkDelete(ctx context.Context, req []state.DeleteRequest) error {
list := make([]string, 0) list := make([]string, 0)
for _, d := range req { for _, d := range req {
list = append(list, d.Key) list = append(list, d.Key)
@ -282,7 +283,7 @@ func (s *RethinkDB) BulkDelete(req []state.DeleteRequest) error {
} }
// Multi performs multiple operations. // Multi performs multiple operations.
func (s *RethinkDB) Multi(req *state.TransactionalStateRequest) error { func (s *RethinkDB) Multi(ctx context.Context, req *state.TransactionalStateRequest) error {
upserts := make([]state.SetRequest, 0) upserts := make([]state.SetRequest, 0)
deletes := make([]state.DeleteRequest, 0) deletes := make([]state.DeleteRequest, 0)
@ -306,11 +307,11 @@ func (s *RethinkDB) Multi(req *state.TransactionalStateRequest) error {
} }
// best effort, no transacts supported // best effort, no transacts supported
if err := s.BulkSet(upserts); err != nil { if err := s.BulkSet(ctx, upserts); err != nil {
return errors.Wrap(err, "error saving records to the database") return errors.Wrap(err, "error saving records to the database")
} }
if err := s.BulkDelete(deletes); err != nil { if err := s.BulkDelete(ctx, deletes); err != nil {
return errors.Wrap(err, "error deleting records to the database") return errors.Wrap(err, "error deleting records to the database")
} }

View File

@ -14,6 +14,7 @@ limitations under the License.
package rethinkdb package rethinkdb
import ( import (
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"os" "os"
@ -86,12 +87,12 @@ func TestRethinkDBStateStore(t *testing.T) {
d := &testObj{F1: "test", F2: 1, F3: time.Now().UTC()} d := &testObj{F1: "test", F2: 1, F3: time.Now().UTC()}
k := fmt.Sprintf("ids-%d", time.Now().UnixNano()) k := fmt.Sprintf("ids-%d", time.Now().UnixNano())
if err := db.Set(&state.SetRequest{Key: k, Value: d}); err != nil { if err := db.Set(context.TODO(), &state.SetRequest{Key: k, Value: d}); err != nil {
t.Fatalf("error setting data to db: %v", err) t.Fatalf("error setting data to db: %v", err)
} }
// get set data and compare // get set data and compare
resp, err := db.Get(&state.GetRequest{Key: k}) resp, err := db.Get(context.TODO(), &state.GetRequest{Key: k})
assert.Nil(t, err) assert.Nil(t, err)
d2 := testGetTestObj(t, resp) d2 := testGetTestObj(t, resp)
assert.NotNil(t, d2) assert.NotNil(t, d2)
@ -103,12 +104,12 @@ func TestRethinkDBStateStore(t *testing.T) {
d2.F2 = 2 d2.F2 = 2
d2.F3 = time.Now().UTC() d2.F3 = time.Now().UTC()
tag := fmt.Sprintf("hash-%d", time.Now().UnixNano()) tag := fmt.Sprintf("hash-%d", time.Now().UnixNano())
if err = db.Set(&state.SetRequest{Key: k, Value: d2, ETag: &tag}); err != nil { if err = db.Set(context.TODO(), &state.SetRequest{Key: k, Value: d2, ETag: &tag}); err != nil {
t.Fatalf("error setting data to db: %v", err) t.Fatalf("error setting data to db: %v", err)
} }
// get updated data and compare // get updated data and compare
resp2, err := db.Get(&state.GetRequest{Key: k}) resp2, err := db.Get(context.TODO(), &state.GetRequest{Key: k})
assert.Nil(t, err) assert.Nil(t, err)
d3 := testGetTestObj(t, resp2) d3 := testGetTestObj(t, resp2)
assert.NotNil(t, d3) assert.NotNil(t, d3)
@ -117,7 +118,7 @@ func TestRethinkDBStateStore(t *testing.T) {
assert.Equal(t, d2.F3.Format(time.RFC3339), d3.F3.Format(time.RFC3339)) assert.Equal(t, d2.F3.Format(time.RFC3339), d3.F3.Format(time.RFC3339))
// delete data // delete data
if err := db.Delete(&state.DeleteRequest{Key: k}); err != nil { if err := db.Delete(context.TODO(), &state.DeleteRequest{Key: k}); err != nil {
t.Fatalf("error on data deletion: %v", err) t.Fatalf("error on data deletion: %v", err)
} }
}) })
@ -127,19 +128,19 @@ func TestRethinkDBStateStore(t *testing.T) {
d := []byte("test") d := []byte("test")
k := fmt.Sprintf("idb-%d", time.Now().UnixNano()) k := fmt.Sprintf("idb-%d", time.Now().UnixNano())
if err := db.Set(&state.SetRequest{Key: k, Value: d}); err != nil { if err := db.Set(context.TODO(), &state.SetRequest{Key: k, Value: d}); err != nil {
t.Fatalf("error setting data to db: %v", err) t.Fatalf("error setting data to db: %v", err)
} }
// get set data and compare // get set data and compare
resp, err := db.Get(&state.GetRequest{Key: k}) resp, err := db.Get(context.TODO(), &state.GetRequest{Key: k})
assert.Nil(t, err) assert.Nil(t, err)
assert.NotNil(t, resp) assert.NotNil(t, resp)
assert.NotNil(t, resp.Data) assert.NotNil(t, resp.Data)
assert.Equal(t, string(d), string(resp.Data)) assert.Equal(t, string(d), string(resp.Data))
// delete data // delete data
if err := db.Delete(&state.DeleteRequest{Key: k}); err != nil { if err := db.Delete(context.TODO(), &state.DeleteRequest{Key: k}); err != nil {
t.Fatalf("error on data deletion: %v", err) t.Fatalf("error on data deletion: %v", err)
} }
}) })
@ -177,26 +178,26 @@ func testBulk(t *testing.T, db *RethinkDB, i int) {
} }
// bulk set it // bulk set it
if err := db.BulkSet(setList); err != nil { if err := db.BulkSet(context.TODO(), setList); err != nil {
t.Fatalf("error setting data to db: %v -- run %d", err, i) t.Fatalf("error setting data to db: %v -- run %d", err, i)
} }
// check for the data // check for the data
for _, v := range deleteList { for _, v := range deleteList {
resp, err := db.Get(&state.GetRequest{Key: v.Key}) resp, err := db.Get(context.TODO(), &state.GetRequest{Key: v.Key})
assert.Nilf(t, err, " -- run %d", i) assert.Nilf(t, err, " -- run %d", i)
assert.NotNil(t, resp) assert.NotNil(t, resp)
assert.NotNil(t, resp.Data) assert.NotNil(t, resp.Data)
} }
// delete data // delete data
if err := db.BulkDelete(deleteList); err != nil { if err := db.BulkDelete(context.TODO(), deleteList); err != nil {
t.Fatalf("error on data deletion: %v -- run %d", err, i) t.Fatalf("error on data deletion: %v -- run %d", err, i)
} }
// check for the data NOT being there // check for the data NOT being there
for _, v := range deleteList { for _, v := range deleteList {
resp, err := db.Get(&state.GetRequest{Key: v.Key}) resp, err := db.Get(context.TODO(), &state.GetRequest{Key: v.Key})
assert.Nilf(t, err, " -- run %d", i) assert.Nilf(t, err, " -- run %d", i)
assert.NotNil(t, resp) assert.NotNil(t, resp)
assert.Nil(t, resp.Data) assert.Nil(t, resp.Data)
@ -224,7 +225,7 @@ func TestRethinkDBStateStoreMulti(t *testing.T) {
for i := 0; i < numOfRecords; i++ { for i := 0; i < numOfRecords; i++ {
list[i] = state.SetRequest{Key: fmt.Sprintf(recordIDFormat, i), Value: d} list[i] = state.SetRequest{Key: fmt.Sprintf(recordIDFormat, i), Value: d}
} }
if err := db.BulkSet(list); err != nil { if err := db.BulkSet(context.TODO(), list); err != nil {
t.Fatalf("error setting multi to db: %v", err) t.Fatalf("error setting multi to db: %v", err)
} }
@ -258,19 +259,19 @@ func TestRethinkDBStateStoreMulti(t *testing.T) {
} }
// execute multi // execute multi
if err := db.Multi(req); err != nil { if err := db.Multi(context.TODO(), req); err != nil {
t.Fatalf("error setting multi to db: %v", err) t.Fatalf("error setting multi to db: %v", err)
} }
// the one not deleted should be still there // the one not deleted should be still there
m1, err := db.Get(&state.GetRequest{Key: fmt.Sprintf(recordIDFormat, 1)}) m1, err := db.Get(context.TODO(), &state.GetRequest{Key: fmt.Sprintf(recordIDFormat, 1)})
assert.Nil(t, err) assert.Nil(t, err)
assert.NotNil(t, m1) assert.NotNil(t, m1)
assert.NotNil(t, m1.Data) assert.NotNil(t, m1.Data)
assert.Equal(t, string(d2), string(m1.Data)) assert.Equal(t, string(d2), string(m1.Data))
// the one deleted should not // the one deleted should not
m2, err := db.Get(&state.GetRequest{Key: fmt.Sprintf(recordIDFormat, 3)}) m2, err := db.Get(context.TODO(), &state.GetRequest{Key: fmt.Sprintf(recordIDFormat, 3)})
assert.Nil(t, err) assert.Nil(t, err)
assert.NotNil(t, m2) assert.NotNil(t, m2)
assert.Nil(t, m2.Data) assert.Nil(t, m2.Data)

View File

@ -14,6 +14,7 @@ limitations under the License.
package sqlserver package sqlserver
import ( import (
"context"
"database/sql" "database/sql"
"encoding/hex" "encoding/hex"
"encoding/json" "encoding/json"
@ -338,7 +339,7 @@ func (s *SQLServer) Features() []state.Feature {
} }
// Multi performs multiple updates on a Sql server store. // Multi performs multiple updates on a Sql server store.
func (s *SQLServer) Multi(request *state.TransactionalStateRequest) error { func (s *SQLServer) Multi(ctx context.Context, request *state.TransactionalStateRequest) error {
tx, err := s.db.Begin() tx, err := s.db.Begin()
if err != nil { if err != nil {
return err return err
@ -353,7 +354,7 @@ func (s *SQLServer) Multi(request *state.TransactionalStateRequest) error {
return err return err
} }
err = s.executeSet(tx, &setReq) err = s.executeSet(ctx, tx, &setReq)
if err != nil { if err != nil {
tx.Rollback() tx.Rollback()
return err return err
@ -366,7 +367,7 @@ func (s *SQLServer) Multi(request *state.TransactionalStateRequest) error {
return err return err
} }
err = s.executeDelete(tx, &delReq) err = s.executeDelete(ctx, tx, &delReq)
if err != nil { if err != nil {
tx.Rollback() tx.Rollback()
return err return err
@ -410,11 +411,11 @@ func (s *SQLServer) getDeletes(req state.TransactionalStateOperation) (state.Del
} }
// Delete removes an entity from the store. // Delete removes an entity from the store.
func (s *SQLServer) Delete(req *state.DeleteRequest) error { func (s *SQLServer) Delete(ctx context.Context, req *state.DeleteRequest) error {
return s.executeDelete(s.db, req) return s.executeDelete(ctx, s.db, req)
} }
func (s *SQLServer) executeDelete(db dbExecutor, req *state.DeleteRequest) error { func (s *SQLServer) executeDelete(ctx context.Context, db dbExecutor, req *state.DeleteRequest) error {
var err error var err error
var res sql.Result var res sql.Result
if req.ETag != nil { if req.ETag != nil {
@ -424,9 +425,9 @@ func (s *SQLServer) executeDelete(db dbExecutor, req *state.DeleteRequest) error
return state.NewETagError(state.ETagInvalid, err) return state.NewETagError(state.ETagInvalid, err)
} }
res, err = db.Exec(s.deleteWithETagCommand, sql.Named(keyColumnName, req.Key), sql.Named(rowVersionColumnName, b)) res, err = db.ExecContext(ctx, s.deleteWithETagCommand, sql.Named(keyColumnName, req.Key), sql.Named(rowVersionColumnName, b))
} else { } else {
res, err = db.Exec(s.deleteWithoutETagCommand, sql.Named(keyColumnName, req.Key)) res, err = db.ExecContext(ctx, s.deleteWithoutETagCommand, sql.Named(keyColumnName, req.Key))
} }
// err represents errors thrown by the stored procedure or the database itself // err represents errors thrown by the stored procedure or the database itself
@ -456,13 +457,13 @@ type TvpDeleteTableStringKey struct {
} }
// BulkDelete removes multiple entries from the store. // BulkDelete removes multiple entries from the store.
func (s *SQLServer) BulkDelete(req []state.DeleteRequest) error { func (s *SQLServer) BulkDelete(ctx context.Context, req []state.DeleteRequest) error {
tx, err := s.db.Begin() tx, err := s.db.Begin()
if err != nil { if err != nil {
return err return err
} }
err = s.executeBulkDelete(tx, req) err = s.executeBulkDelete(ctx, tx, req)
if err != nil { if err != nil {
tx.Rollback() tx.Rollback()
@ -474,7 +475,7 @@ func (s *SQLServer) BulkDelete(req []state.DeleteRequest) error {
return nil return nil
} }
func (s *SQLServer) executeBulkDelete(db dbExecutor, req []state.DeleteRequest) error { func (s *SQLServer) executeBulkDelete(ctx context.Context, db dbExecutor, req []state.DeleteRequest) error {
values := make([]TvpDeleteTableStringKey, len(req)) values := make([]TvpDeleteTableStringKey, len(req))
for i, d := range req { for i, d := range req {
var etag []byte var etag []byte
@ -493,7 +494,7 @@ func (s *SQLServer) executeBulkDelete(db dbExecutor, req []state.DeleteRequest)
Value: values, Value: values,
} }
res, err := db.Exec(s.bulkDeleteCommand, sql.Named("itemsToDelete", itemsToDelete)) res, err := db.ExecContext(ctx, s.bulkDeleteCommand, sql.Named("itemsToDelete", itemsToDelete))
if err != nil { if err != nil {
return err return err
} }
@ -513,7 +514,7 @@ func (s *SQLServer) executeBulkDelete(db dbExecutor, req []state.DeleteRequest)
} }
// Get returns an entity from store. // Get returns an entity from store.
func (s *SQLServer) Get(req *state.GetRequest) (*state.GetResponse, error) { func (s *SQLServer) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) {
rows, err := s.db.Query(s.getCommand, sql.Named(keyColumnName, req.Key)) rows, err := s.db.Query(s.getCommand, sql.Named(keyColumnName, req.Key))
if err != nil { if err != nil {
return nil, err return nil, err
@ -545,21 +546,21 @@ func (s *SQLServer) Get(req *state.GetRequest) (*state.GetResponse, error) {
} }
// BulkGet performs a bulks get operations. // BulkGet performs a bulks get operations.
func (s *SQLServer) BulkGet(req []state.GetRequest) (bool, []state.BulkGetResponse, error) { func (s *SQLServer) BulkGet(ctx context.Context, req []state.GetRequest) (bool, []state.BulkGetResponse, error) {
return false, nil, nil return false, nil, nil
} }
// Set adds/updates an entity on store. // Set adds/updates an entity on store.
func (s *SQLServer) Set(req *state.SetRequest) error { func (s *SQLServer) Set(ctx context.Context, req *state.SetRequest) error {
return s.executeSet(s.db, req) return s.executeSet(ctx, s.db, req)
} }
// dbExecutor implements a common functionality implemented by db or tx. // dbExecutor implements a common functionality implemented by db or tx.
type dbExecutor interface { type dbExecutor interface {
Exec(query string, args ...interface{}) (sql.Result, error) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
} }
func (s *SQLServer) executeSet(db dbExecutor, req *state.SetRequest) error { func (s *SQLServer) executeSet(ctx context.Context, db dbExecutor, req *state.SetRequest) error {
var err error var err error
var bytes []byte var bytes []byte
bytes, err = utils.Marshal(req.Value, json.Marshal) bytes, err = utils.Marshal(req.Value, json.Marshal)
@ -578,9 +579,9 @@ func (s *SQLServer) executeSet(db dbExecutor, req *state.SetRequest) error {
var res sql.Result var res sql.Result
if req.Options.Concurrency == state.FirstWrite { if req.Options.Concurrency == state.FirstWrite {
res, err = db.Exec(s.upsertCommand, sql.Named(keyColumnName, req.Key), sql.Named("Data", string(bytes)), etag, sql.Named("FirstWrite", 1)) res, err = db.ExecContext(ctx, s.upsertCommand, sql.Named(keyColumnName, req.Key), sql.Named("Data", string(bytes)), etag, sql.Named("FirstWrite", 1))
} else { } else {
res, err = db.Exec(s.upsertCommand, sql.Named(keyColumnName, req.Key), sql.Named("Data", string(bytes)), etag, sql.Named("FirstWrite", 0)) res, err = db.ExecContext(ctx, s.upsertCommand, sql.Named(keyColumnName, req.Key), sql.Named("Data", string(bytes)), etag, sql.Named("FirstWrite", 0))
} }
if err != nil { if err != nil {
@ -604,14 +605,14 @@ func (s *SQLServer) executeSet(db dbExecutor, req *state.SetRequest) error {
} }
// BulkSet adds/updates multiple entities on store. // BulkSet adds/updates multiple entities on store.
func (s *SQLServer) BulkSet(req []state.SetRequest) error { func (s *SQLServer) BulkSet(ctx context.Context, req []state.SetRequest) error {
tx, err := s.db.Begin() tx, err := s.db.Begin()
if err != nil { if err != nil {
return err return err
} }
for i := range req { for i := range req {
err = s.executeSet(tx, &req[i]) err = s.executeSet(ctx, tx, &req[i])
if err != nil { if err != nil {
tx.Rollback() tx.Rollback()

View File

@ -15,6 +15,7 @@ limitations under the License.
package sqlserver package sqlserver
import ( import (
"context"
"database/sql" "database/sql"
"encoding/json" "encoding/json"
"fmt" "fmt"
@ -130,7 +131,7 @@ func getTestStoreWithKeyType(t *testing.T, kt KeyType, indexedProperties string)
} }
func assertUserExists(t *testing.T, store *SQLServer, key string) (user, string) { func assertUserExists(t *testing.T, store *SQLServer, key string) (user, string) {
getRes, err := store.Get(&state.GetRequest{Key: key}) getRes, err := store.Get(context.TODO(), &state.GetRequest{Key: key})
assert.Nil(t, err) assert.Nil(t, err)
assert.NotNil(t, getRes) assert.NotNil(t, getRes)
assert.NotNil(t, getRes.Data, "No data was returned") assert.NotNil(t, getRes.Data, "No data was returned")
@ -153,7 +154,7 @@ func assertLoadedUserIsEqual(t *testing.T, store *SQLServer, key string, expecte
} }
func assertUserDoesNotExist(t *testing.T, store *SQLServer, key string) { func assertUserDoesNotExist(t *testing.T, store *SQLServer, key string) {
_, err := store.Get(&state.GetRequest{Key: key}) _, err := store.Get(context.TODO(), &state.GetRequest{Key: key})
assert.Nil(t, err) assert.Nil(t, err)
} }
@ -224,14 +225,14 @@ func testSingleOperations(t *testing.T) {
assertUserDoesNotExist(t, store, john.ID) assertUserDoesNotExist(t, store, john.ID)
// Save and read // Save and read
err := store.Set(&state.SetRequest{Key: john.ID, Value: john}) err := store.Set(context.TODO(), &state.SetRequest{Key: john.ID, Value: john})
assert.Nil(t, err) assert.Nil(t, err)
johnV1, etagFromInsert := assertLoadedUserIsEqual(t, store, john.ID, john) johnV1, etagFromInsert := assertLoadedUserIsEqual(t, store, john.ID, john)
// Update with ETAG // Update with ETAG
waterJohn := johnV1 waterJohn := johnV1
waterJohn.FavoriteBeverage = "Water" waterJohn.FavoriteBeverage = "Water"
err = store.Set(&state.SetRequest{Key: waterJohn.ID, Value: waterJohn, ETag: &etagFromInsert}) err = store.Set(context.TODO(), &state.SetRequest{Key: waterJohn.ID, Value: waterJohn, ETag: &etagFromInsert})
assert.Nil(t, err) assert.Nil(t, err)
// Get updated // Get updated
@ -240,7 +241,7 @@ func testSingleOperations(t *testing.T) {
// Update without ETAG // Update without ETAG
noEtagJohn := johnV2 noEtagJohn := johnV2
noEtagJohn.FavoriteBeverage = "No Etag John" noEtagJohn.FavoriteBeverage = "No Etag John"
err = store.Set(&state.SetRequest{Key: noEtagJohn.ID, Value: noEtagJohn}) err = store.Set(context.TODO(), &state.SetRequest{Key: noEtagJohn.ID, Value: noEtagJohn})
assert.Nil(t, err) assert.Nil(t, err)
// 7. Get updated // 7. Get updated
@ -249,17 +250,17 @@ func testSingleOperations(t *testing.T) {
// 8. Update with invalid ETAG should fail // 8. Update with invalid ETAG should fail
failedJohn := johnV3 failedJohn := johnV3
failedJohn.FavoriteBeverage = "Will not work" failedJohn.FavoriteBeverage = "Will not work"
err = store.Set(&state.SetRequest{Key: failedJohn.ID, Value: failedJohn, ETag: &etagFromInsert}) err = store.Set(context.TODO(), &state.SetRequest{Key: failedJohn.ID, Value: failedJohn, ETag: &etagFromInsert})
assert.NotNil(t, err) assert.NotNil(t, err)
_, etag := assertLoadedUserIsEqual(t, store, johnV3.ID, johnV3) _, etag := assertLoadedUserIsEqual(t, store, johnV3.ID, johnV3)
// 9. Delete with invalid ETAG should fail // 9. Delete with invalid ETAG should fail
err = store.Delete(&state.DeleteRequest{Key: johnV3.ID, ETag: &invEtag}) err = store.Delete(context.TODO(), &state.DeleteRequest{Key: johnV3.ID, ETag: &invEtag})
assert.NotNil(t, err) assert.NotNil(t, err)
assertLoadedUserIsEqual(t, store, johnV3.ID, johnV3) assertLoadedUserIsEqual(t, store, johnV3.ID, johnV3)
// 10. Delete with valid ETAG // 10. Delete with valid ETAG
err = store.Delete(&state.DeleteRequest{Key: johnV2.ID, ETag: &etag}) err = store.Delete(context.TODO(), &state.DeleteRequest{Key: johnV2.ID, ETag: &etag})
assert.Nil(t, err) assert.Nil(t, err)
assertUserDoesNotExist(t, store, johnV2.ID) assertUserDoesNotExist(t, store, johnV2.ID)
@ -273,7 +274,7 @@ func testSetNewRecordWithInvalidEtagShouldFail(t *testing.T) {
u := user{uuid.New().String(), "John", "Coffee"} u := user{uuid.New().String(), "John", "Coffee"}
invEtag := invalidEtag invEtag := invalidEtag
err := store.Set(&state.SetRequest{Key: u.ID, Value: u, ETag: &invEtag}) err := store.Set(context.TODO(), &state.SetRequest{Key: u.ID, Value: u, ETag: &invEtag})
assert.NotNil(t, err) assert.NotNil(t, err)
} }
@ -281,7 +282,7 @@ func testSetNewRecordWithInvalidEtagShouldFail(t *testing.T) {
func testIndexedProperties(t *testing.T) { func testIndexedProperties(t *testing.T) {
store := getTestStore(t, `[{ "column":"FavoriteBeverage", "property":"FavoriteBeverage", "type":"nvarchar(100)"}, { "column":"PetsCount", "property":"PetsCount", "type": "INTEGER"}]`) store := getTestStore(t, `[{ "column":"FavoriteBeverage", "property":"FavoriteBeverage", "type":"nvarchar(100)"}, { "column":"PetsCount", "property":"PetsCount", "type": "INTEGER"}]`)
err := store.BulkSet([]state.SetRequest{ err := store.BulkSet(context.TODO(), []state.SetRequest{
{Key: "1", Value: userWithPets{user{"1", "John", "Coffee"}, 3}}, {Key: "1", Value: userWithPets{user{"1", "John", "Coffee"}, 3}},
{Key: "2", Value: userWithPets{user{"2", "Laura", "Water"}, 1}}, {Key: "2", Value: userWithPets{user{"2", "Laura", "Water"}, 1}},
{Key: "3", Value: userWithPets{user{"3", "Carl", "Beer"}, 0}}, {Key: "3", Value: userWithPets{user{"3", "Carl", "Beer"}, 0}},
@ -343,7 +344,7 @@ func testMultiOperations(t *testing.T) {
bulkSet[i] = state.SetRequest{Key: u.ID, Value: u} bulkSet[i] = state.SetRequest{Key: u.ID, Value: u}
} }
err := store.BulkSet(bulkSet) err := store.BulkSet(context.TODO(), bulkSet)
assert.Nil(t, err) assert.Nil(t, err)
assertUserCountIsEqualTo(t, store, len(initialUsers)) assertUserCountIsEqualTo(t, store, len(initialUsers))
@ -363,7 +364,7 @@ func testMultiOperations(t *testing.T) {
modified := original.user modified := original.user
modified.FavoriteBeverage = beverageTea modified.FavoriteBeverage = beverageTea
localErr := store.Multi(&state.TransactionalStateRequest{ localErr := store.Multi(context.TODO(), &state.TransactionalStateRequest{
Operations: []state.TransactionalStateOperation{ Operations: []state.TransactionalStateOperation{
{Operation: state.Delete, Request: state.DeleteRequest{Key: toDelete.ID}}, {Operation: state.Delete, Request: state.DeleteRequest{Key: toDelete.ID}},
{Operation: state.Upsert, Request: state.SetRequest{Key: modified.ID, Value: modified}}, {Operation: state.Upsert, Request: state.SetRequest{Key: modified.ID, Value: modified}},
@ -386,7 +387,7 @@ func testMultiOperations(t *testing.T) {
modified := toModify.user modified := toModify.user
modified.FavoriteBeverage = beverageTea modified.FavoriteBeverage = beverageTea
err = store.Multi(&state.TransactionalStateRequest{ err = store.Multi(context.TODO(), &state.TransactionalStateRequest{
Operations: []state.TransactionalStateOperation{ Operations: []state.TransactionalStateOperation{
{Operation: state.Delete, Request: state.DeleteRequest{Key: toDelete.ID, ETag: &toDelete.etag}}, {Operation: state.Delete, Request: state.DeleteRequest{Key: toDelete.ID, ETag: &toDelete.etag}},
{Operation: state.Upsert, Request: state.SetRequest{Key: modified.ID, Value: modified, ETag: &toModify.etag}}, {Operation: state.Upsert, Request: state.SetRequest{Key: modified.ID, Value: modified, ETag: &toModify.etag}},
@ -410,7 +411,7 @@ func testMultiOperations(t *testing.T) {
modified := toModify.user modified := toModify.user
modified.FavoriteBeverage = beverageTea modified.FavoriteBeverage = beverageTea
err = store.Multi(&state.TransactionalStateRequest{ err = store.Multi(context.TODO(), &state.TransactionalStateRequest{
Operations: []state.TransactionalStateOperation{ Operations: []state.TransactionalStateOperation{
{Operation: state.Delete, Request: state.DeleteRequest{Key: toDelete.ID, ETag: &toDelete.etag}}, {Operation: state.Delete, Request: state.DeleteRequest{Key: toDelete.ID, ETag: &toDelete.etag}},
{Operation: state.Upsert, Request: state.SetRequest{Key: modified.ID, Value: modified, ETag: &toModify.etag}}, {Operation: state.Upsert, Request: state.SetRequest{Key: modified.ID, Value: modified, ETag: &toModify.etag}},
@ -431,7 +432,7 @@ func testMultiOperations(t *testing.T) {
toInsert := user{keyGen.NextKey(), "Wont-be-inserted", "Beer"} toInsert := user{keyGen.NextKey(), "Wont-be-inserted", "Beer"}
invEtag := invalidEtag invEtag := invalidEtag
err = store.Multi(&state.TransactionalStateRequest{ err = store.Multi(context.TODO(), &state.TransactionalStateRequest{
Operations: []state.TransactionalStateOperation{ Operations: []state.TransactionalStateOperation{
{Operation: state.Delete, Request: state.DeleteRequest{Key: toDelete.ID, ETag: &invEtag}}, {Operation: state.Delete, Request: state.DeleteRequest{Key: toDelete.ID, ETag: &invEtag}},
{Operation: state.Upsert, Request: state.SetRequest{Key: toInsert.ID, Value: toInsert}}, {Operation: state.Upsert, Request: state.SetRequest{Key: toInsert.ID, Value: toInsert}},
@ -452,7 +453,7 @@ func testMultiOperations(t *testing.T) {
modified.FavoriteBeverage = beverageTea modified.FavoriteBeverage = beverageTea
invEtag := invalidEtag invEtag := invalidEtag
err = store.Multi(&state.TransactionalStateRequest{ err = store.Multi(context.TODO(), &state.TransactionalStateRequest{
Operations: []state.TransactionalStateOperation{ Operations: []state.TransactionalStateOperation{
{Operation: state.Delete, Request: state.DeleteRequest{Key: toDelete.ID, ETag: &invEtag}}, {Operation: state.Delete, Request: state.DeleteRequest{Key: toDelete.ID, ETag: &invEtag}},
{Operation: state.Upsert, Request: state.SetRequest{Key: modified.ID, Value: modified}}, {Operation: state.Upsert, Request: state.SetRequest{Key: modified.ID, Value: modified}},
@ -472,7 +473,7 @@ func testMultiOperations(t *testing.T) {
modified.FavoriteBeverage = beverageTea modified.FavoriteBeverage = beverageTea
invEtag := invalidEtag invEtag := invalidEtag
err = store.Multi(&state.TransactionalStateRequest{ err = store.Multi(context.TODO(), &state.TransactionalStateRequest{
Operations: []state.TransactionalStateOperation{ Operations: []state.TransactionalStateOperation{
{Operation: state.Delete, Request: state.DeleteRequest{Key: toDelete.ID}}, {Operation: state.Delete, Request: state.DeleteRequest{Key: toDelete.ID}},
{Operation: state.Upsert, Request: state.SetRequest{Key: modified.ID, Value: modified, ETag: &invEtag}}, {Operation: state.Upsert, Request: state.SetRequest{Key: modified.ID, Value: modified, ETag: &invEtag}},
@ -520,7 +521,7 @@ func testBulkSet(t *testing.T) {
sets[i] = state.SetRequest{Key: u.ID, Value: u} sets[i] = state.SetRequest{Key: u.ID, Value: u}
} }
err := store.BulkSet(sets) err := store.BulkSet(context.TODO(), sets)
assert.Nil(t, err) assert.Nil(t, err)
totalUsers = len(sets) totalUsers = len(sets)
assertUserCountIsEqualTo(t, store, totalUsers) assertUserCountIsEqualTo(t, store, totalUsers)
@ -532,7 +533,7 @@ func testBulkSet(t *testing.T) {
modified.FavoriteBeverage = beverageTea modified.FavoriteBeverage = beverageTea
toInsert := user{keyGen.NextKey(), "Maria", "Wine"} toInsert := user{keyGen.NextKey(), "Maria", "Wine"}
err := store.BulkSet([]state.SetRequest{ err := store.BulkSet(context.TODO(), []state.SetRequest{
{Key: modified.ID, Value: modified, ETag: &toModifyETag}, {Key: modified.ID, Value: modified, ETag: &toModifyETag},
{Key: toInsert.ID, Value: toInsert}, {Key: toInsert.ID, Value: toInsert},
}) })
@ -551,7 +552,7 @@ func testBulkSet(t *testing.T) {
modified.FavoriteBeverage = beverageTea modified.FavoriteBeverage = beverageTea
toInsert := user{keyGen.NextKey(), "Tony", "Milk"} toInsert := user{keyGen.NextKey(), "Tony", "Milk"}
err := store.BulkSet([]state.SetRequest{ err := store.BulkSet(context.TODO(), []state.SetRequest{
{Key: modified.ID, Value: modified}, {Key: modified.ID, Value: modified},
{Key: toInsert.ID, Value: toInsert}, {Key: toInsert.ID, Value: toInsert},
}) })
@ -578,7 +579,7 @@ func testBulkSet(t *testing.T) {
{Key: modified.ID, Value: modified, ETag: &invEtag}, {Key: modified.ID, Value: modified, ETag: &invEtag},
} }
err := store.BulkSet(sets) err := store.BulkSet(context.TODO(), sets)
assert.NotNil(t, err) assert.NotNil(t, err)
assertUserCountIsEqualTo(t, store, totalUsers) assertUserCountIsEqualTo(t, store, totalUsers)
assertUserDoesNotExist(t, store, toInsert1.ID) assertUserDoesNotExist(t, store, toInsert1.ID)
@ -621,7 +622,7 @@ func testBulkDelete(t *testing.T) {
for i, u := range initialUsers { for i, u := range initialUsers {
sets[i] = state.SetRequest{Key: u.ID, Value: u} sets[i] = state.SetRequest{Key: u.ID, Value: u}
} }
err := store.BulkSet(sets) err := store.BulkSet(context.TODO(), sets)
assert.Nil(t, err) assert.Nil(t, err)
totalUsers := len(initialUsers) totalUsers := len(initialUsers)
assertUserCountIsEqualTo(t, store, totalUsers) assertUserCountIsEqualTo(t, store, totalUsers)
@ -631,7 +632,7 @@ func testBulkDelete(t *testing.T) {
t.Run("Delete 2 items without etag should work", func(t *testing.T) { t.Run("Delete 2 items without etag should work", func(t *testing.T) {
deleted1 := initialUsers[userIndex].ID deleted1 := initialUsers[userIndex].ID
deleted2 := initialUsers[userIndex+1].ID deleted2 := initialUsers[userIndex+1].ID
err := store.BulkDelete([]state.DeleteRequest{ err := store.BulkDelete(context.TODO(), []state.DeleteRequest{
{Key: deleted1}, {Key: deleted1},
{Key: deleted2}, {Key: deleted2},
}) })
@ -648,7 +649,7 @@ func testBulkDelete(t *testing.T) {
deleted1, deleted1Etag := assertUserExists(t, store, initialUsers[userIndex].ID) deleted1, deleted1Etag := assertUserExists(t, store, initialUsers[userIndex].ID)
deleted2, deleted2Etag := assertUserExists(t, store, initialUsers[userIndex+1].ID) deleted2, deleted2Etag := assertUserExists(t, store, initialUsers[userIndex+1].ID)
err := store.BulkDelete([]state.DeleteRequest{ err := store.BulkDelete(context.TODO(), []state.DeleteRequest{
{Key: deleted1.ID, ETag: &deleted1Etag}, {Key: deleted1.ID, ETag: &deleted1Etag},
{Key: deleted2.ID, ETag: &deleted2Etag}, {Key: deleted2.ID, ETag: &deleted2Etag},
}) })
@ -665,7 +666,7 @@ func testBulkDelete(t *testing.T) {
deleted1, deleted1Etag := assertUserExists(t, store, initialUsers[userIndex].ID) deleted1, deleted1Etag := assertUserExists(t, store, initialUsers[userIndex].ID)
deleted2 := initialUsers[userIndex+1] deleted2 := initialUsers[userIndex+1]
err := store.BulkDelete([]state.DeleteRequest{ err := store.BulkDelete(context.TODO(), []state.DeleteRequest{
{Key: deleted1.ID, ETag: &deleted1Etag}, {Key: deleted1.ID, ETag: &deleted1Etag},
{Key: deleted2.ID}, {Key: deleted2.ID},
}) })
@ -683,7 +684,7 @@ func testBulkDelete(t *testing.T) {
deleted2 := initialUsers[userIndex+1] deleted2 := initialUsers[userIndex+1]
invEtag := invalidEtag invEtag := invalidEtag
err := store.BulkDelete([]state.DeleteRequest{ err := store.BulkDelete(context.TODO(), []state.DeleteRequest{
{Key: deleted1.ID, ETag: &deleted1Etag}, {Key: deleted1.ID, ETag: &deleted1Etag},
{Key: deleted2.ID, ETag: &invEtag}, {Key: deleted2.ID, ETag: &invEtag},
}) })
@ -703,7 +704,7 @@ func testInsertAndUpdateSetRecordDates(t *testing.T) {
store := getTestStore(t, "") store := getTestStore(t, "")
u := user{"1", "John", "Coffee"} u := user{"1", "John", "Coffee"}
err := store.Set(&state.SetRequest{Key: u.ID, Value: u}) err := store.Set(context.TODO(), &state.SetRequest{Key: u.ID, Value: u})
assert.Nil(t, err) assert.Nil(t, err)
var originalInsertTime time.Time var originalInsertTime time.Time
@ -725,7 +726,7 @@ func testInsertAndUpdateSetRecordDates(t *testing.T) {
modified := u modified := u
modified.FavoriteBeverage = beverageTea modified.FavoriteBeverage = beverageTea
err = store.Set(&state.SetRequest{Key: modified.ID, Value: modified}) err = store.Set(context.TODO(), &state.SetRequest{Key: modified.ID, Value: modified})
assert.Nil(t, err) assert.Nil(t, err)
assertDBQuery(t, store, getUserTsql, func(t *testing.T, rows *sql.Rows) { assertDBQuery(t, store, getUserTsql, func(t *testing.T, rows *sql.Rows) {
assert.True(t, rows.Next()) assert.True(t, rows.Next())
@ -749,7 +750,7 @@ func testConcurrentSets(t *testing.T) {
store := getTestStore(t, "") store := getTestStore(t, "")
u := user{"1", "John", "Coffee"} u := user{"1", "John", "Coffee"}
err := store.Set(&state.SetRequest{Key: u.ID, Value: u}) err := store.Set(context.TODO(), &state.SetRequest{Key: u.ID, Value: u})
assert.Nil(t, err) assert.Nil(t, err)
_, etag := assertLoadedUserIsEqual(t, store, u.ID, u) _, etag := assertLoadedUserIsEqual(t, store, u.ID, u)
@ -766,7 +767,7 @@ func testConcurrentSets(t *testing.T) {
defer wc.Done() defer wc.Done()
modified := user{"1", "John", beverageTea} modified := user{"1", "John", beverageTea}
err := store.Set(&state.SetRequest{Key: id, Value: modified, ETag: &etag}) err := store.Set(context.TODO(), &state.SetRequest{Key: id, Value: modified, ETag: &etag})
if err != nil { if err != nil {
atomic.AddInt32(&totalErrors, 1) atomic.AddInt32(&totalErrors, 1)
} else { } else {

View File

@ -14,6 +14,7 @@ limitations under the License.
package state package state
import ( import (
"context"
"fmt" "fmt"
"github.com/dapr/components-contrib/health" "github.com/dapr/components-contrib/health"
@ -24,9 +25,9 @@ type Store interface {
BulkStore BulkStore
Init(metadata Metadata) error Init(metadata Metadata) error
Features() []Feature Features() []Feature
Delete(req *DeleteRequest) error Delete(ctx context.Context, req *DeleteRequest) error
Get(req *GetRequest) (*GetResponse, error) Get(ctx context.Context, req *GetRequest) (*GetResponse, error)
Set(req *SetRequest) error Set(ctx context.Context, req *SetRequest) error
} }
func Ping(store Store) error { func Ping(store Store) error {
@ -40,9 +41,9 @@ func Ping(store Store) error {
// BulkStore is an interface to perform bulk operations on store. // BulkStore is an interface to perform bulk operations on store.
type BulkStore interface { type BulkStore interface {
BulkDelete(req []DeleteRequest) error BulkDelete(ctx context.Context, req []DeleteRequest) error
BulkGet(req []GetRequest) (bool, []BulkGetResponse, error) BulkGet(ctx context.Context, req []GetRequest) (bool, []BulkGetResponse, error)
BulkSet(req []SetRequest) error BulkSet(ctx context.Context, req []SetRequest) error
} }
// DefaultBulkStore is a default implementation of BulkStore. // DefaultBulkStore is a default implementation of BulkStore.
@ -64,16 +65,16 @@ func (b *DefaultBulkStore) Features() []Feature {
} }
// BulkGet performs a bulks get operations. // BulkGet performs a bulks get operations.
func (b *DefaultBulkStore) BulkGet(req []GetRequest) (bool, []BulkGetResponse, error) { func (b *DefaultBulkStore) BulkGet(ctx context.Context, req []GetRequest) (bool, []BulkGetResponse, error) {
// by default, the store doesn't support bulk get // by default, the store doesn't support bulk get
// return false so daprd will fallback to call get() method one by one // return false so daprd will fallback to call get() method one by one
return false, nil, nil return false, nil, nil
} }
// BulkSet performs a bulks save operation. // BulkSet performs a bulks save operation.
func (b *DefaultBulkStore) BulkSet(req []SetRequest) error { func (b *DefaultBulkStore) BulkSet(ctx context.Context, req []SetRequest) error {
for i := range req { for i := range req {
err := b.s.Set(&req[i]) err := b.s.Set(ctx, &req[i])
if err != nil { if err != nil {
return err return err
} }
@ -83,9 +84,9 @@ func (b *DefaultBulkStore) BulkSet(req []SetRequest) error {
} }
// BulkDelete performs a bulk delete operation. // BulkDelete performs a bulk delete operation.
func (b *DefaultBulkStore) BulkDelete(req []DeleteRequest) error { func (b *DefaultBulkStore) BulkDelete(ctx context.Context, req []DeleteRequest) error {
for i := range req { for i := range req {
err := b.s.Delete(&req[i]) err := b.s.Delete(ctx, &req[i])
if err != nil { if err != nil {
return err return err
} }
@ -96,5 +97,5 @@ func (b *DefaultBulkStore) BulkDelete(req []DeleteRequest) error {
// Querier is an interface to execute queries. // Querier is an interface to execute queries.
type Querier interface { type Querier interface {
Query(req *QueryRequest) (*QueryResponse, error) Query(ctx context.Context, req *QueryRequest) (*QueryResponse, error)
} }

View File

@ -14,6 +14,7 @@ limitations under the License.
package state package state
import ( import (
"context"
"testing" "testing"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@ -26,22 +27,22 @@ func TestStore_withDefaultBulkImpl(t *testing.T) {
require.Equal(t, s.count, 0) require.Equal(t, s.count, 0)
require.Equal(t, s.bulkCount, 0) require.Equal(t, s.bulkCount, 0)
store.Get(&GetRequest{}) store.Get(context.TODO(), &GetRequest{})
store.Set(&SetRequest{}) store.Set(context.TODO(), &SetRequest{})
store.Delete(&DeleteRequest{}) store.Delete(context.TODO(), &DeleteRequest{})
require.Equal(t, 3, s.count) require.Equal(t, 3, s.count)
require.Equal(t, 0, s.bulkCount) require.Equal(t, 0, s.bulkCount)
bulkGet, responses, err := store.BulkGet([]GetRequest{{}, {}, {}}) bulkGet, responses, err := store.BulkGet(context.TODO(), []GetRequest{{}, {}, {}})
require.Equal(t, false, bulkGet) require.Equal(t, false, bulkGet)
require.Equal(t, 0, len(responses)) require.Equal(t, 0, len(responses))
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 3, s.count) require.Equal(t, 3, s.count)
require.Equal(t, 0, s.bulkCount) require.Equal(t, 0, s.bulkCount)
store.BulkSet([]SetRequest{{}, {}, {}, {}}) store.BulkSet(context.TODO(), []SetRequest{{}, {}, {}, {}})
require.Equal(t, 3+4, s.count) require.Equal(t, 3+4, s.count)
require.Equal(t, 0, s.bulkCount) require.Equal(t, 0, s.bulkCount)
store.BulkDelete([]DeleteRequest{{}, {}, {}, {}, {}}) store.BulkDelete(context.TODO(), []DeleteRequest{{}, {}, {}, {}, {}})
require.Equal(t, 3+4+5, s.count) require.Equal(t, 3+4+5, s.count)
require.Equal(t, 0, s.bulkCount) require.Equal(t, 0, s.bulkCount)
} }
@ -52,20 +53,20 @@ func TestStore_withCustomisedBulkImpl_notSupportBulkGet(t *testing.T) {
require.Equal(t, s.count, 0) require.Equal(t, s.count, 0)
require.Equal(t, s.bulkCount, 0) require.Equal(t, s.bulkCount, 0)
store.Get(&GetRequest{}) store.Get(context.TODO(), &GetRequest{})
store.Set(&SetRequest{}) store.Set(context.TODO(), &SetRequest{})
store.Delete(&DeleteRequest{}) store.Delete(context.TODO(), &DeleteRequest{})
require.Equal(t, 3, s.count) require.Equal(t, 3, s.count)
require.Equal(t, 0, s.bulkCount) require.Equal(t, 0, s.bulkCount)
bulkGet, _, _ := store.BulkGet([]GetRequest{{}, {}, {}}) bulkGet, _, _ := store.BulkGet(context.TODO(), []GetRequest{{}, {}, {}})
require.Equal(t, false, bulkGet) require.Equal(t, false, bulkGet)
require.Equal(t, 6, s.count) require.Equal(t, 6, s.count)
require.Equal(t, 0, s.bulkCount) require.Equal(t, 0, s.bulkCount)
store.BulkSet([]SetRequest{{}, {}, {}, {}}) store.BulkSet(context.TODO(), []SetRequest{{}, {}, {}, {}})
require.Equal(t, 6, s.count) require.Equal(t, 6, s.count)
require.Equal(t, 1, s.bulkCount) require.Equal(t, 1, s.bulkCount)
store.BulkDelete([]DeleteRequest{{}, {}, {}, {}, {}}) store.BulkDelete(context.TODO(), []DeleteRequest{{}, {}, {}, {}, {}})
require.Equal(t, 6, s.count) require.Equal(t, 6, s.count)
require.Equal(t, 2, s.bulkCount) require.Equal(t, 2, s.bulkCount)
} }
@ -76,20 +77,20 @@ func TestStore_withCustomisedBulkImpl_supportBulkGet(t *testing.T) {
require.Equal(t, s.count, 0) require.Equal(t, s.count, 0)
require.Equal(t, s.bulkCount, 0) require.Equal(t, s.bulkCount, 0)
store.Get(&GetRequest{}) store.Get(context.TODO(), &GetRequest{})
store.Set(&SetRequest{}) store.Set(context.TODO(), &SetRequest{})
store.Delete(&DeleteRequest{}) store.Delete(context.TODO(), &DeleteRequest{})
require.Equal(t, 3, s.count) require.Equal(t, 3, s.count)
require.Equal(t, 0, s.bulkCount) require.Equal(t, 0, s.bulkCount)
bulkGet, _, _ := store.BulkGet([]GetRequest{{}, {}, {}}) bulkGet, _, _ := store.BulkGet(context.TODO(), []GetRequest{{}, {}, {}})
require.Equal(t, true, bulkGet) require.Equal(t, true, bulkGet)
require.Equal(t, 3, s.count) require.Equal(t, 3, s.count)
require.Equal(t, 1, s.bulkCount) require.Equal(t, 1, s.bulkCount)
store.BulkSet([]SetRequest{{}, {}, {}, {}}) store.BulkSet(context.TODO(), []SetRequest{{}, {}, {}, {}})
require.Equal(t, 3, s.count) require.Equal(t, 3, s.count)
require.Equal(t, 2, s.bulkCount) require.Equal(t, 2, s.bulkCount)
store.BulkDelete([]DeleteRequest{{}, {}, {}, {}, {}}) store.BulkDelete(context.TODO(), []DeleteRequest{{}, {}, {}, {}, {}})
require.Equal(t, 3, s.count) require.Equal(t, 3, s.count)
require.Equal(t, 3, s.bulkCount) require.Equal(t, 3, s.bulkCount)
} }
@ -110,19 +111,19 @@ func (s *Store1) Init(metadata Metadata) error {
return nil return nil
} }
func (s *Store1) Delete(req *DeleteRequest) error { func (s *Store1) Delete(ctx context.Context, req *DeleteRequest) error {
s.count++ s.count++
return nil return nil
} }
func (s *Store1) Get(req *GetRequest) (*GetResponse, error) { func (s *Store1) Get(ctx context.Context, req *GetRequest) (*GetResponse, error) {
s.count++ s.count++
return &GetResponse{}, nil return &GetResponse{}, nil
} }
func (s *Store1) Set(req *SetRequest) error { func (s *Store1) Set(ctx context.Context, req *SetRequest) error {
s.count++ s.count++
return nil return nil
@ -145,25 +146,25 @@ func (s *Store2) Features() []Feature {
return nil return nil
} }
func (s *Store2) Delete(req *DeleteRequest) error { func (s *Store2) Delete(ctx context.Context, req *DeleteRequest) error {
s.count++ s.count++
return nil return nil
} }
func (s *Store2) Get(req *GetRequest) (*GetResponse, error) { func (s *Store2) Get(ctx context.Context, req *GetRequest) (*GetResponse, error) {
s.count++ s.count++
return &GetResponse{}, nil return &GetResponse{}, nil
} }
func (s *Store2) Set(req *SetRequest) error { func (s *Store2) Set(ctx context.Context, req *SetRequest) error {
s.count++ s.count++
return nil return nil
} }
func (s *Store2) BulkGet(req []GetRequest) (bool, []BulkGetResponse, error) { func (s *Store2) BulkGet(ctx context.Context, req []GetRequest) (bool, []BulkGetResponse, error) {
if s.supportBulkGet { if s.supportBulkGet {
s.bulkCount++ s.bulkCount++
@ -175,13 +176,13 @@ func (s *Store2) BulkGet(req []GetRequest) (bool, []BulkGetResponse, error) {
return false, nil, nil return false, nil, nil
} }
func (s *Store2) BulkSet(req []SetRequest) error { func (s *Store2) BulkSet(ctx context.Context, req []SetRequest) error {
s.bulkCount++ s.bulkCount++
return nil return nil
} }
func (s *Store2) BulkDelete(req []DeleteRequest) error { func (s *Store2) BulkDelete(ctx context.Context, req []DeleteRequest) error {
s.bulkCount++ s.bulkCount++
return nil return nil

View File

@ -14,6 +14,7 @@ limitations under the License.
package zookeeper package zookeeper
import ( import (
"context"
"errors" "errors"
"path" "path"
"strconv" "strconv"
@ -161,7 +162,7 @@ func (s *StateStore) Features() []state.Feature {
} }
// Get retrieves state from Zookeeper with a key. // Get retrieves state from Zookeeper with a key.
func (s *StateStore) Get(req *state.GetRequest) (*state.GetResponse, error) { func (s *StateStore) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) {
value, stat, err := s.conn.Get(s.prefixedKey(req.Key)) value, stat, err := s.conn.Get(s.prefixedKey(req.Key))
if err != nil { if err != nil {
if errors.Is(err, zk.ErrNoNode) { if errors.Is(err, zk.ErrNoNode) {
@ -178,19 +179,19 @@ func (s *StateStore) Get(req *state.GetRequest) (*state.GetResponse, error) {
} }
// BulkGet performs a bulks get operations. // BulkGet performs a bulks get operations.
func (s *StateStore) BulkGet(req []state.GetRequest) (bool, []state.BulkGetResponse, error) { func (s *StateStore) BulkGet(ctx context.Context, req []state.GetRequest) (bool, []state.BulkGetResponse, error) {
// TODO: replace with Multi for performance // TODO: replace with Multi for performance
return false, nil, nil return false, nil, nil
} }
// Delete performs a delete operation. // Delete performs a delete operation.
func (s *StateStore) Delete(req *state.DeleteRequest) error { func (s *StateStore) Delete(ctx context.Context, req *state.DeleteRequest) error {
r, err := s.newDeleteRequest(req) r, err := s.newDeleteRequest(req)
if err != nil { if err != nil {
return err return err
} }
return state.DeleteWithOptions(func(req *state.DeleteRequest) error { return state.DeleteWithOptions(ctx, func(ctx context.Context, req *state.DeleteRequest) error {
err := s.conn.Delete(r.Path, r.Version) err := s.conn.Delete(r.Path, r.Version)
if errors.Is(err, zk.ErrNoNode) { if errors.Is(err, zk.ErrNoNode) {
return nil return nil
@ -209,7 +210,7 @@ func (s *StateStore) Delete(req *state.DeleteRequest) error {
} }
// BulkDelete performs a bulk delete operation. // BulkDelete performs a bulk delete operation.
func (s *StateStore) BulkDelete(reqs []state.DeleteRequest) error { func (s *StateStore) BulkDelete(ctx context.Context, reqs []state.DeleteRequest) error {
ops := make([]interface{}, 0, len(reqs)) ops := make([]interface{}, 0, len(reqs))
for i := range reqs { for i := range reqs {
@ -236,13 +237,13 @@ func (s *StateStore) BulkDelete(reqs []state.DeleteRequest) error {
} }
// Set saves state into Zookeeper. // Set saves state into Zookeeper.
func (s *StateStore) Set(req *state.SetRequest) error { func (s *StateStore) Set(ctx context.Context, req *state.SetRequest) error {
r, err := s.newSetDataRequest(req) r, err := s.newSetDataRequest(req)
if err != nil { if err != nil {
return err return err
} }
return state.SetWithOptions(func(req *state.SetRequest) error { return state.SetWithOptions(ctx, func(ctx context.Context, req *state.SetRequest) error {
_, err = s.conn.Set(r.Path, r.Data, r.Version) _, err = s.conn.Set(r.Path, r.Data, r.Version)
if errors.Is(err, zk.ErrNoNode) { if errors.Is(err, zk.ErrNoNode) {
@ -262,7 +263,7 @@ func (s *StateStore) Set(req *state.SetRequest) error {
} }
// BulkSet performs a bulks save operation. // BulkSet performs a bulks save operation.
func (s *StateStore) BulkSet(reqs []state.SetRequest) error { func (s *StateStore) BulkSet(ctx context.Context, reqs []state.SetRequest) error {
ops := make([]interface{}, 0, len(reqs)) ops := make([]interface{}, 0, len(reqs))
for i := range reqs { for i := range reqs {

View File

@ -14,6 +14,7 @@ limitations under the License.
package zookeeper package zookeeper
import ( import (
"context"
"fmt" "fmt"
"testing" "testing"
"time" "time"
@ -80,7 +81,7 @@ func TestGet(t *testing.T) {
t.Run("With key exists", func(t *testing.T) { t.Run("With key exists", func(t *testing.T) {
conn.EXPECT().Get("foo").Return([]byte("bar"), &zk.Stat{Version: 123}, nil).Times(1) conn.EXPECT().Get("foo").Return([]byte("bar"), &zk.Stat{Version: 123}, nil).Times(1)
res, err := s.Get(&state.GetRequest{Key: "foo"}) res, err := s.Get(context.TODO(), &state.GetRequest{Key: "foo"})
assert.NotNil(t, res, "Key must be exists") assert.NotNil(t, res, "Key must be exists")
assert.Equal(t, "bar", string(res.Data), "Value must be equals") assert.Equal(t, "bar", string(res.Data), "Value must be equals")
assert.Equal(t, ptr.String("123"), res.ETag, "ETag must be equals") assert.Equal(t, ptr.String("123"), res.ETag, "ETag must be equals")
@ -90,7 +91,7 @@ func TestGet(t *testing.T) {
t.Run("With key non-exists", func(t *testing.T) { t.Run("With key non-exists", func(t *testing.T) {
conn.EXPECT().Get("foo").Return(nil, nil, zk.ErrNoNode).Times(1) conn.EXPECT().Get("foo").Return(nil, nil, zk.ErrNoNode).Times(1)
res, err := s.Get(&state.GetRequest{Key: "foo"}) res, err := s.Get(context.TODO(), &state.GetRequest{Key: "foo"})
assert.Equal(t, &state.GetResponse{}, res, "Response must be empty") assert.Equal(t, &state.GetResponse{}, res, "Response must be empty")
assert.NoError(t, err, "Non-existent key must not be treated as error") assert.NoError(t, err, "Non-existent key must not be treated as error")
}) })
@ -108,21 +109,21 @@ func TestDelete(t *testing.T) {
t.Run("With key", func(t *testing.T) { t.Run("With key", func(t *testing.T) {
conn.EXPECT().Delete("foo", int32(anyVersion)).Return(nil).Times(1) conn.EXPECT().Delete("foo", int32(anyVersion)).Return(nil).Times(1)
err := s.Delete(&state.DeleteRequest{Key: "foo"}) err := s.Delete(context.TODO(), &state.DeleteRequest{Key: "foo"})
assert.NoError(t, err, "Key must be exists") assert.NoError(t, err, "Key must be exists")
}) })
t.Run("With key and version", func(t *testing.T) { t.Run("With key and version", func(t *testing.T) {
conn.EXPECT().Delete("foo", int32(123)).Return(nil).Times(1) conn.EXPECT().Delete("foo", int32(123)).Return(nil).Times(1)
err := s.Delete(&state.DeleteRequest{Key: "foo", ETag: &etag}) err := s.Delete(context.TODO(), &state.DeleteRequest{Key: "foo", ETag: &etag})
assert.NoError(t, err, "Key must be exists") assert.NoError(t, err, "Key must be exists")
}) })
t.Run("With key and concurrency", func(t *testing.T) { t.Run("With key and concurrency", func(t *testing.T) {
conn.EXPECT().Delete("foo", int32(anyVersion)).Return(nil).Times(1) conn.EXPECT().Delete("foo", int32(anyVersion)).Return(nil).Times(1)
err := s.Delete(&state.DeleteRequest{ err := s.Delete(context.TODO(), &state.DeleteRequest{
Key: "foo", Key: "foo",
ETag: &etag, ETag: &etag,
Options: state.DeleteStateOption{Concurrency: state.LastWrite}, Options: state.DeleteStateOption{Concurrency: state.LastWrite},
@ -133,14 +134,14 @@ func TestDelete(t *testing.T) {
t.Run("With delete error", func(t *testing.T) { t.Run("With delete error", func(t *testing.T) {
conn.EXPECT().Delete("foo", int32(anyVersion)).Return(zk.ErrUnknown).Times(1) conn.EXPECT().Delete("foo", int32(anyVersion)).Return(zk.ErrUnknown).Times(1)
err := s.Delete(&state.DeleteRequest{Key: "foo"}) err := s.Delete(context.TODO(), &state.DeleteRequest{Key: "foo"})
assert.EqualError(t, err, "zk: unknown error") assert.EqualError(t, err, "zk: unknown error")
}) })
t.Run("With delete and ignore NoNode error", func(t *testing.T) { t.Run("With delete and ignore NoNode error", func(t *testing.T) {
conn.EXPECT().Delete("foo", int32(anyVersion)).Return(zk.ErrNoNode).Times(1) conn.EXPECT().Delete("foo", int32(anyVersion)).Return(zk.ErrNoNode).Times(1)
err := s.Delete(&state.DeleteRequest{Key: "foo"}) err := s.Delete(context.TODO(), &state.DeleteRequest{Key: "foo"})
assert.NoError(t, err, "Delete must be successful") assert.NoError(t, err, "Delete must be successful")
}) })
} }
@ -159,7 +160,7 @@ func TestBulkDelete(t *testing.T) {
&zk.DeleteRequest{Path: "bar", Version: int32(anyVersion)}, &zk.DeleteRequest{Path: "bar", Version: int32(anyVersion)},
}).Return([]zk.MultiResponse{{}, {}}, nil).Times(1) }).Return([]zk.MultiResponse{{}, {}}, nil).Times(1)
err := s.BulkDelete([]state.DeleteRequest{{Key: "foo"}, {Key: "bar"}}) err := s.BulkDelete(context.TODO(), []state.DeleteRequest{{Key: "foo"}, {Key: "bar"}})
assert.NoError(t, err, "Key must be exists") assert.NoError(t, err, "Key must be exists")
}) })
@ -171,7 +172,7 @@ func TestBulkDelete(t *testing.T) {
{Error: zk.ErrUnknown}, {Error: zk.ErrNoAuth}, {Error: zk.ErrUnknown}, {Error: zk.ErrNoAuth},
}, nil).Times(1) }, nil).Times(1)
err := s.BulkDelete([]state.DeleteRequest{{Key: "foo"}, {Key: "bar"}}) err := s.BulkDelete(context.TODO(), []state.DeleteRequest{{Key: "foo"}, {Key: "bar"}})
assert.Equal(t, err.(*multierror.Error).Errors, []error{zk.ErrUnknown, zk.ErrNoAuth}) assert.Equal(t, err.(*multierror.Error).Errors, []error{zk.ErrUnknown, zk.ErrNoAuth})
}) })
t.Run("With keys and ignore NoNode error", func(t *testing.T) { t.Run("With keys and ignore NoNode error", func(t *testing.T) {
@ -182,7 +183,7 @@ func TestBulkDelete(t *testing.T) {
{Error: zk.ErrNoNode}, {}, {Error: zk.ErrNoNode}, {},
}, nil).Times(1) }, nil).Times(1)
err := s.BulkDelete([]state.DeleteRequest{{Key: "foo"}, {Key: "bar"}}) err := s.BulkDelete(context.TODO(), []state.DeleteRequest{{Key: "foo"}, {Key: "bar"}})
assert.NoError(t, err, "Key must be exists") assert.NoError(t, err, "Key must be exists")
}) })
} }
@ -201,19 +202,19 @@ func TestSet(t *testing.T) {
t.Run("With key", func(t *testing.T) { t.Run("With key", func(t *testing.T) {
conn.EXPECT().Set("foo", []byte("\"bar\""), int32(anyVersion)).Return(stat, nil).Times(1) conn.EXPECT().Set("foo", []byte("\"bar\""), int32(anyVersion)).Return(stat, nil).Times(1)
err := s.Set(&state.SetRequest{Key: "foo", Value: "bar"}) err := s.Set(context.TODO(), &state.SetRequest{Key: "foo", Value: "bar"})
assert.NoError(t, err, "Key must be set") assert.NoError(t, err, "Key must be set")
}) })
t.Run("With key and version", func(t *testing.T) { t.Run("With key and version", func(t *testing.T) {
conn.EXPECT().Set("foo", []byte("\"bar\""), int32(123)).Return(stat, nil).Times(1) conn.EXPECT().Set("foo", []byte("\"bar\""), int32(123)).Return(stat, nil).Times(1)
err := s.Set(&state.SetRequest{Key: "foo", Value: "bar", ETag: &etag}) err := s.Set(context.TODO(), &state.SetRequest{Key: "foo", Value: "bar", ETag: &etag})
assert.NoError(t, err, "Key must be set") assert.NoError(t, err, "Key must be set")
}) })
t.Run("With key and concurrency", func(t *testing.T) { t.Run("With key and concurrency", func(t *testing.T) {
conn.EXPECT().Set("foo", []byte("\"bar\""), int32(anyVersion)).Return(stat, nil).Times(1) conn.EXPECT().Set("foo", []byte("\"bar\""), int32(anyVersion)).Return(stat, nil).Times(1)
err := s.Set(&state.SetRequest{ err := s.Set(context.TODO(), &state.SetRequest{
Key: "foo", Key: "foo",
Value: "bar", Value: "bar",
ETag: &etag, ETag: &etag,
@ -225,14 +226,14 @@ func TestSet(t *testing.T) {
t.Run("With error", func(t *testing.T) { t.Run("With error", func(t *testing.T) {
conn.EXPECT().Set("foo", []byte("\"bar\""), int32(anyVersion)).Return(nil, zk.ErrUnknown).Times(1) conn.EXPECT().Set("foo", []byte("\"bar\""), int32(anyVersion)).Return(nil, zk.ErrUnknown).Times(1)
err := s.Set(&state.SetRequest{Key: "foo", Value: "bar"}) err := s.Set(context.TODO(), &state.SetRequest{Key: "foo", Value: "bar"})
assert.EqualError(t, err, "zk: unknown error") assert.EqualError(t, err, "zk: unknown error")
}) })
t.Run("With NoNode error and retry", func(t *testing.T) { t.Run("With NoNode error and retry", func(t *testing.T) {
conn.EXPECT().Set("foo", []byte("\"bar\""), int32(anyVersion)).Return(nil, zk.ErrNoNode).Times(1) conn.EXPECT().Set("foo", []byte("\"bar\""), int32(anyVersion)).Return(nil, zk.ErrNoNode).Times(1)
conn.EXPECT().Create("foo", []byte("\"bar\""), int32(0), nil).Return("/foo", nil).Times(1) conn.EXPECT().Create("foo", []byte("\"bar\""), int32(0), nil).Return("/foo", nil).Times(1)
err := s.Set(&state.SetRequest{Key: "foo", Value: "bar"}) err := s.Set(context.TODO(), &state.SetRequest{Key: "foo", Value: "bar"})
assert.NoError(t, err, "Key must be create") assert.NoError(t, err, "Key must be create")
}) })
} }
@ -251,7 +252,7 @@ func TestBulkSet(t *testing.T) {
&zk.SetDataRequest{Path: "bar", Data: []byte("\"foo\""), Version: int32(anyVersion)}, &zk.SetDataRequest{Path: "bar", Data: []byte("\"foo\""), Version: int32(anyVersion)},
}).Return([]zk.MultiResponse{{}, {}}, nil).Times(1) }).Return([]zk.MultiResponse{{}, {}}, nil).Times(1)
err := s.BulkSet([]state.SetRequest{ err := s.BulkSet(context.TODO(), []state.SetRequest{
{Key: "foo", Value: "bar"}, {Key: "foo", Value: "bar"},
{Key: "bar", Value: "foo"}, {Key: "bar", Value: "foo"},
}) })
@ -266,7 +267,7 @@ func TestBulkSet(t *testing.T) {
{Error: zk.ErrUnknown}, {Error: zk.ErrNoAuth}, {Error: zk.ErrUnknown}, {Error: zk.ErrNoAuth},
}, nil).Times(1) }, nil).Times(1)
err := s.BulkSet([]state.SetRequest{ err := s.BulkSet(context.TODO(), []state.SetRequest{
{Key: "foo", Value: "bar"}, {Key: "foo", Value: "bar"},
{Key: "bar", Value: "foo"}, {Key: "bar", Value: "foo"},
}) })
@ -283,7 +284,7 @@ func TestBulkSet(t *testing.T) {
&zk.CreateRequest{Path: "foo", Data: []byte("\"bar\"")}, &zk.CreateRequest{Path: "foo", Data: []byte("\"bar\"")},
}).Return([]zk.MultiResponse{{}, {}}, nil).Times(1) }).Return([]zk.MultiResponse{{}, {}}, nil).Times(1)
err := s.BulkSet([]state.SetRequest{ err := s.BulkSet(context.TODO(), []state.SetRequest{
{Key: "foo", Value: "bar"}, {Key: "foo", Value: "bar"},
{Key: "bar", Value: "foo"}, {Key: "bar", Value: "foo"},
}) })

View File

@ -14,6 +14,7 @@ limitations under the License.
package state package state
import ( import (
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"sort" "sort"
@ -251,7 +252,7 @@ func ConformanceTests(t *testing.T, props map[string]string, statestore state.St
if len(scenario.contentType) != 0 { if len(scenario.contentType) != 0 {
req.Metadata = map[string]string{metadata.ContentType: scenario.contentType} req.Metadata = map[string]string{metadata.ContentType: scenario.contentType}
} }
err := statestore.Set(req) err := statestore.Set(context.TODO(), req)
assert.Nil(t, err) assert.Nil(t, err)
} }
} }
@ -269,7 +270,7 @@ func ConformanceTests(t *testing.T, props map[string]string, statestore state.St
if len(scenario.contentType) != 0 { if len(scenario.contentType) != 0 {
req.Metadata = map[string]string{metadata.ContentType: scenario.contentType} req.Metadata = map[string]string{metadata.ContentType: scenario.contentType}
} }
res, err := statestore.Get(req) res, err := statestore.Get(context.TODO(), req)
assert.Nil(t, err) assert.Nil(t, err)
assertEquals(t, scenario.value, res) assertEquals(t, scenario.value, res)
} }
@ -290,7 +291,7 @@ func ConformanceTests(t *testing.T, props map[string]string, statestore state.St
metadata.ContentType: contenttype.JSONContentType, metadata.ContentType: contenttype.JSONContentType,
metadata.QueryIndexName: "qIndx", metadata.QueryIndexName: "qIndx",
} }
resp, err := querier.Query(&req) resp, err := querier.Query(context.TODO(), &req)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, len(scenario.results), len(resp.Results)) assert.Equal(t, len(scenario.results), len(resp.Results))
for i := range scenario.results { for i := range scenario.results {
@ -318,11 +319,11 @@ func ConformanceTests(t *testing.T, props map[string]string, statestore state.St
if len(scenario.contentType) != 0 { if len(scenario.contentType) != 0 {
req.Metadata = map[string]string{metadata.ContentType: scenario.contentType} req.Metadata = map[string]string{metadata.ContentType: scenario.contentType}
} }
err := statestore.Delete(req) err := statestore.Delete(context.TODO(), req)
assert.Nil(t, err, "no error expected while deleting %s", scenario.key) assert.Nil(t, err, "no error expected while deleting %s", scenario.key)
t.Logf("Checking value absence for %s", scenario.key) t.Logf("Checking value absence for %s", scenario.key)
res, err := statestore.Get(&state.GetRequest{ res, err := statestore.Get(context.TODO(), &state.GetRequest{
Key: scenario.key, Key: scenario.key,
}) })
assert.Nil(t, err, "no error expected while checking for absence for %s", scenario.key) assert.Nil(t, err, "no error expected while checking for absence for %s", scenario.key)
@ -344,14 +345,14 @@ func ConformanceTests(t *testing.T, props map[string]string, statestore state.St
}) })
} }
} }
err := statestore.BulkSet(bulk) err := statestore.BulkSet(context.TODO(), bulk)
assert.Nil(t, err) assert.Nil(t, err)
for _, scenario := range scenarios { for _, scenario := range scenarios {
if scenario.bulkOnly { if scenario.bulkOnly {
t.Logf("Checking value presence for %s", scenario.key) t.Logf("Checking value presence for %s", scenario.key)
// Data should have been inserted at this point // Data should have been inserted at this point
res, err := statestore.Get(&state.GetRequest{ res, err := statestore.Get(context.TODO(), &state.GetRequest{
Key: scenario.key, Key: scenario.key,
}) })
assert.Nil(t, err) assert.Nil(t, err)
@ -372,12 +373,12 @@ func ConformanceTests(t *testing.T, props map[string]string, statestore state.St
}) })
} }
} }
err := statestore.BulkDelete(bulk) err := statestore.BulkDelete(context.TODO(), bulk)
assert.Nil(t, err) assert.Nil(t, err)
for _, req := range bulk { for _, req := range bulk {
t.Logf("Checking value absence for %s", req.Key) t.Logf("Checking value absence for %s", req.Key)
res, err := statestore.Get(&state.GetRequest{ res, err := statestore.Get(context.TODO(), &state.GetRequest{
Key: req.Key, Key: req.Key,
}) })
assert.Nil(t, err) assert.Nil(t, err)
@ -443,7 +444,7 @@ func ConformanceTests(t *testing.T, props map[string]string, statestore state.St
if scenario.transactionGroup == transactionGroup { if scenario.transactionGroup == transactionGroup {
t.Logf("Checking value presence for %s", scenario.key) t.Logf("Checking value presence for %s", scenario.key)
// Data should have been inserted at this point // Data should have been inserted at this point
res, err := statestore.Get(&state.GetRequest{ res, err := statestore.Get(context.TODO(), &state.GetRequest{
Key: scenario.key, Key: scenario.key,
// For CosmosDB // For CosmosDB
Metadata: map[string]string{ Metadata: map[string]string{
@ -457,7 +458,7 @@ func ConformanceTests(t *testing.T, props map[string]string, statestore state.St
if scenario.toBeDeleted && (scenario.transactionGroup == transactionGroup-1) { if scenario.toBeDeleted && (scenario.transactionGroup == transactionGroup-1) {
t.Logf("Checking value absence for %s", scenario.key) t.Logf("Checking value absence for %s", scenario.key)
// Data should have been deleted at this point // Data should have been deleted at this point
res, err := statestore.Get(&state.GetRequest{ res, err := statestore.Get(context.TODO(), &state.GetRequest{
Key: scenario.key, Key: scenario.key,
// For CosmosDB // For CosmosDB
Metadata: map[string]string{ Metadata: map[string]string{
@ -487,7 +488,7 @@ func ConformanceTests(t *testing.T, props map[string]string, statestore state.St
} }
// prerequisite: key1 should be present // prerequisite: key1 should be present
err := statestore.Set(&state.SetRequest{ err := statestore.Set(context.TODO(), &state.SetRequest{
Key: firstKey, Key: firstKey,
Value: firstValue, Value: firstValue,
Metadata: partitionMetadata, Metadata: partitionMetadata,
@ -495,14 +496,14 @@ func ConformanceTests(t *testing.T, props map[string]string, statestore state.St
assert.NoError(t, err, "set request should be successful") assert.NoError(t, err, "set request should be successful")
// prerequisite: key2 should not be present // prerequisite: key2 should not be present
err = statestore.Delete(&state.DeleteRequest{ err = statestore.Delete(context.TODO(), &state.DeleteRequest{
Key: secondKey, Key: secondKey,
Metadata: partitionMetadata, Metadata: partitionMetadata,
}) })
assert.NoError(t, err, "delete request should be successful") assert.NoError(t, err, "delete request should be successful")
// prerequisite: key3 should not be present // prerequisite: key3 should not be present
err = statestore.Delete(&state.DeleteRequest{ err = statestore.Delete(context.TODO(), &state.DeleteRequest{
Key: thirdKey, Key: thirdKey,
Metadata: partitionMetadata, Metadata: partitionMetadata,
}) })
@ -558,7 +559,7 @@ func ConformanceTests(t *testing.T, props map[string]string, statestore state.St
// Assert // Assert
for k, v := range expected { for k, v := range expected {
res, err := statestore.Get(&state.GetRequest{ res, err := statestore.Get(context.TODO(), &state.GetRequest{
Key: k, Key: k,
Metadata: partitionMetadata, Metadata: partitionMetadata,
}) })
@ -585,20 +586,20 @@ func ConformanceTests(t *testing.T, props map[string]string, statestore state.St
require.True(t, state.FeatureETag.IsPresent(features)) require.True(t, state.FeatureETag.IsPresent(features))
// Delete any potential object, it's important to start from a clean slate. // Delete any potential object, it's important to start from a clean slate.
err := statestore.Delete(&state.DeleteRequest{ err := statestore.Delete(context.TODO(), &state.DeleteRequest{
Key: testKey, Key: testKey,
}) })
require.Nil(t, err) require.Nil(t, err)
// Set an object. // Set an object.
err = statestore.Set(&state.SetRequest{ err = statestore.Set(context.TODO(), &state.SetRequest{
Key: testKey, Key: testKey,
Value: firstValue, Value: firstValue,
}) })
require.Nil(t, err) require.Nil(t, err)
// Validate the set. // Validate the set.
res, err := statestore.Get(&state.GetRequest{ res, err := statestore.Get(context.TODO(), &state.GetRequest{
Key: testKey, Key: testKey,
}) })
@ -607,7 +608,7 @@ func ConformanceTests(t *testing.T, props map[string]string, statestore state.St
etag := res.ETag etag := res.ETag
// Try and update with wrong ETag, expect failure. // Try and update with wrong ETag, expect failure.
err = statestore.Set(&state.SetRequest{ err = statestore.Set(context.TODO(), &state.SetRequest{
Key: testKey, Key: testKey,
Value: secondValue, Value: secondValue,
ETag: &fakeEtag, ETag: &fakeEtag,
@ -615,7 +616,7 @@ func ConformanceTests(t *testing.T, props map[string]string, statestore state.St
require.NotNil(t, err) require.NotNil(t, err)
// Try and update with corect ETag, expect success. // Try and update with corect ETag, expect success.
err = statestore.Set(&state.SetRequest{ err = statestore.Set(context.TODO(), &state.SetRequest{
Key: testKey, Key: testKey,
Value: secondValue, Value: secondValue,
ETag: etag, ETag: etag,
@ -623,7 +624,7 @@ func ConformanceTests(t *testing.T, props map[string]string, statestore state.St
require.Nil(t, err) require.Nil(t, err)
// Validate the set. // Validate the set.
res, err = statestore.Get(&state.GetRequest{ res, err = statestore.Get(context.TODO(), &state.GetRequest{
Key: testKey, Key: testKey,
}) })
require.Nil(t, err) require.Nil(t, err)
@ -632,14 +633,14 @@ func ConformanceTests(t *testing.T, props map[string]string, statestore state.St
etag = res.ETag etag = res.ETag
// Try and delete with wrong ETag, expect failure. // Try and delete with wrong ETag, expect failure.
err = statestore.Delete(&state.DeleteRequest{ err = statestore.Delete(context.TODO(), &state.DeleteRequest{
Key: testKey, Key: testKey,
ETag: &fakeEtag, ETag: &fakeEtag,
}) })
require.NotNil(t, err) require.NotNil(t, err)
// Try and delete with correct ETag, expect success. // Try and delete with correct ETag, expect success.
err = statestore.Delete(&state.DeleteRequest{ err = statestore.Delete(context.TODO(), &state.DeleteRequest{
Key: testKey, Key: testKey,
ETag: etag, ETag: etag,
}) })
@ -698,23 +699,23 @@ func ConformanceTests(t *testing.T, props map[string]string, statestore state.St
for i, requestSet := range requestSets { for i, requestSet := range requestSets {
t.Run(fmt.Sprintf("request set %d", i), func(t *testing.T) { t.Run(fmt.Sprintf("request set %d", i), func(t *testing.T) {
// Delete any potential object, it's important to start from a clean slate. // Delete any potential object, it's important to start from a clean slate.
err := statestore.Delete(&state.DeleteRequest{ err := statestore.Delete(context.TODO(), &state.DeleteRequest{
Key: testKey, Key: testKey,
}) })
require.Nil(t, err) require.Nil(t, err)
err = statestore.Set(requestSet[0]) err = statestore.Set(context.TODO(), requestSet[0])
require.Nil(t, err) require.Nil(t, err)
// Validate the set. // Validate the set.
res, err := statestore.Get(&state.GetRequest{ res, err := statestore.Get(context.TODO(), &state.GetRequest{
Key: testKey, Key: testKey,
}) })
require.Nil(t, err) require.Nil(t, err)
assertEquals(t, firstValue, res) assertEquals(t, firstValue, res)
// Second write expect fail // Second write expect fail
err = statestore.Set(requestSet[1]) err = statestore.Set(context.TODO(), requestSet[1])
require.NotNil(t, err) require.NotNil(t, err)
}) })
} }
@ -731,16 +732,16 @@ func ConformanceTests(t *testing.T, props map[string]string, statestore state.St
} }
// Delete any potential object, it's important to start from a clean slate. // Delete any potential object, it's important to start from a clean slate.
err := statestore.Delete(&state.DeleteRequest{ err := statestore.Delete(context.TODO(), &state.DeleteRequest{
Key: testKey, Key: testKey,
}) })
require.Nil(t, err) require.Nil(t, err)
err = statestore.Set(request) err = statestore.Set(context.TODO(), request)
require.Nil(t, err) require.Nil(t, err)
// Validate the set. // Validate the set.
res, err := statestore.Get(&state.GetRequest{ res, err := statestore.Get(context.TODO(), &state.GetRequest{
Key: testKey, Key: testKey,
}) })
require.Nil(t, err) require.Nil(t, err)
@ -757,11 +758,11 @@ func ConformanceTests(t *testing.T, props map[string]string, statestore state.St
Consistency: state.Strong, Consistency: state.Strong,
}, },
} }
err = statestore.Set(request) err = statestore.Set(context.TODO(), request)
require.Nil(t, err) require.Nil(t, err)
// Validate the set. // Validate the set.
res, err = statestore.Get(&state.GetRequest{ res, err = statestore.Get(context.TODO(), &state.GetRequest{
Key: testKey, Key: testKey,
}) })
require.Nil(t, err) require.Nil(t, err)
@ -771,7 +772,7 @@ func ConformanceTests(t *testing.T, props map[string]string, statestore state.St
request.ETag = etag request.ETag = etag
// Second write expect fail // Second write expect fail
err = statestore.Set(request) err = statestore.Set(context.TODO(), request)
require.NotNil(t, err) require.NotNil(t, err)
}) })
} }