diff --git a/state/aerospike/aerospike.go b/state/aerospike/aerospike.go index 68f85654d..f2beb4fdf 100644 --- a/state/aerospike/aerospike.go +++ b/state/aerospike/aerospike.go @@ -14,6 +14,7 @@ limitations under the License. package aerospike import ( + "context" "encoding/json" "errors" "fmt" @@ -110,7 +111,7 @@ func (aspike *Aerospike) Features() []state.Feature { } // Set stores value for a key to Aerospike. It honors ETag (for concurrency) and consistency settings. -func (aspike *Aerospike) Set(req *state.SetRequest) error { +func (aspike *Aerospike) Set(ctx context.Context, req *state.SetRequest) error { err := state.CheckRequestOptions(req.Options) if err != nil { return err @@ -162,7 +163,7 @@ func (aspike *Aerospike) Set(req *state.SetRequest) error { } // Get retrieves state from Aerospike with a key. -func (aspike *Aerospike) Get(req *state.GetRequest) (*state.GetResponse, error) { +func (aspike *Aerospike) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) { asKey, err := as.NewKey(aspike.namespace, aspike.set, req.Key) if err != nil { return nil, err @@ -196,7 +197,7 @@ func (aspike *Aerospike) Get(req *state.GetRequest) (*state.GetResponse, error) } // Delete performs a delete operation. -func (aspike *Aerospike) Delete(req *state.DeleteRequest) error { +func (aspike *Aerospike) Delete(ctx context.Context, req *state.DeleteRequest) error { err := state.CheckRequestOptions(req.Options) if err != nil { return err diff --git a/state/alicloud/tablestore/tablestore.go b/state/alicloud/tablestore/tablestore.go index 6d5583d7c..3e47f1fab 100644 --- a/state/alicloud/tablestore/tablestore.go +++ b/state/alicloud/tablestore/tablestore.go @@ -14,6 +14,7 @@ limitations under the License. package tablestore import ( + "context" "encoding/json" "github.com/agrea/ptr" @@ -68,7 +69,7 @@ func (s *AliCloudTableStore) Features() []state.Feature { return s.features } -func (s *AliCloudTableStore) Get(req *state.GetRequest) (*state.GetResponse, error) { +func (s *AliCloudTableStore) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) { criteria := &tablestore.SingleRowQueryCriteria{ PrimaryKey: s.primaryKey(req.Key), TableName: s.metadata.TableName, @@ -103,7 +104,7 @@ func (s *AliCloudTableStore) getResp(columns []*tablestore.AttributeColumn) *sta return getResp } -func (s *AliCloudTableStore) BulkGet(reqs []state.GetRequest) (bool, []state.BulkGetResponse, error) { +func (s *AliCloudTableStore) BulkGet(ctx context.Context, reqs []state.GetRequest) (bool, []state.BulkGetResponse, error) { // "len == 0": empty request, directly return empty response if len(reqs) == 0 { return true, []state.BulkGetResponse{}, nil @@ -139,7 +140,7 @@ func (s *AliCloudTableStore) BulkGet(reqs []state.GetRequest) (bool, []state.Bul return true, responseList, nil } -func (s *AliCloudTableStore) Set(req *state.SetRequest) error { +func (s *AliCloudTableStore) Set(ctx context.Context, req *state.SetRequest) error { change := s.updateRowChange(req) request := &tablestore.UpdateRowRequest{ @@ -183,7 +184,7 @@ func unmarshal(val interface{}) []byte { return []byte(output) } -func (s *AliCloudTableStore) Delete(req *state.DeleteRequest) error { +func (s *AliCloudTableStore) Delete(ctx context.Context, req *state.DeleteRequest) error { change := s.deleteRowChange(req) deleteRowReq := &tablestore.DeleteRowRequest{ @@ -205,15 +206,15 @@ func (s *AliCloudTableStore) deleteRowChange(req *state.DeleteRequest) *tablesto return change } -func (s *AliCloudTableStore) BulkSet(reqs []state.SetRequest) error { - return s.batchWrite(reqs, nil) +func (s *AliCloudTableStore) BulkSet(ctx context.Context, reqs []state.SetRequest) error { + return s.batchWrite(ctx, reqs, nil) } -func (s *AliCloudTableStore) BulkDelete(reqs []state.DeleteRequest) error { - return s.batchWrite(nil, reqs) +func (s *AliCloudTableStore) BulkDelete(ctx context.Context, reqs []state.DeleteRequest) error { + return s.batchWrite(ctx, nil, reqs) } -func (s *AliCloudTableStore) batchWrite(setReqs []state.SetRequest, deleteReqs []state.DeleteRequest) error { +func (s *AliCloudTableStore) batchWrite(ctx context.Context, setReqs []state.SetRequest, deleteReqs []state.DeleteRequest) error { bathReq := &tablestore.BatchWriteRowRequest{ IsAtomic: true, } diff --git a/state/alicloud/tablestore/tablestore_test.go b/state/alicloud/tablestore/tablestore_test.go index d946a7be2..4a6459027 100644 --- a/state/alicloud/tablestore/tablestore_test.go +++ b/state/alicloud/tablestore/tablestore_test.go @@ -14,6 +14,7 @@ limitations under the License. package tablestore import ( + "context" "testing" "github.com/agrea/ptr" @@ -63,7 +64,7 @@ func TestReadAndWrite(t *testing.T) { Value: "value of key", ETag: ptr.String("the etag"), } - err := store.Set(setReq) + err := store.Set(context.TODO(), setReq) assert.Nil(t, err) }) @@ -71,7 +72,7 @@ func TestReadAndWrite(t *testing.T) { getReq := &state.GetRequest{ Key: "theFirstKey", } - resp, err := store.Get(getReq) + resp, err := store.Get(context.TODO(), getReq) assert.Nil(t, err) assert.NotNil(t, resp) assert.Equal(t, "value of key", string(resp.Data)) @@ -83,7 +84,7 @@ func TestReadAndWrite(t *testing.T) { Value: "1234", ETag: ptr.String("the etag"), } - err := store.Set(setReq) + err := store.Set(context.TODO(), setReq) assert.Nil(t, err) }) @@ -91,14 +92,14 @@ func TestReadAndWrite(t *testing.T) { getReq := &state.GetRequest{ Key: "theSecondKey", } - resp, err := store.Get(getReq) + resp, err := store.Get(context.TODO(), getReq) assert.Nil(t, err) assert.NotNil(t, resp) assert.Equal(t, "1234", string(resp.Data)) }) t.Run("test BulkSet", func(t *testing.T) { - err := store.BulkSet([]state.SetRequest{{ + err := store.BulkSet(context.TODO(), []state.SetRequest{{ Key: "theFirstKey", Value: "666", }, { @@ -110,7 +111,7 @@ func TestReadAndWrite(t *testing.T) { }) t.Run("test BulkGet", func(t *testing.T) { - _, resp, err := store.BulkGet([]state.GetRequest{{ + _, resp, err := store.BulkGet(context.TODO(), []state.GetRequest{{ Key: "theFirstKey", }, { Key: "theSecondKey", @@ -126,12 +127,12 @@ func TestReadAndWrite(t *testing.T) { req := &state.DeleteRequest{ Key: "theFirstKey", } - err := store.Delete(req) + err := store.Delete(context.TODO(), req) assert.Nil(t, err) }) t.Run("test BulkGet2", func(t *testing.T) { - _, resp, err := store.BulkGet([]state.GetRequest{{ + _, resp, err := store.BulkGet(context.TODO(), []state.GetRequest{{ Key: "theFirstKey", }, { Key: "theSecondKey", diff --git a/state/aws/dynamodb/dynamodb.go b/state/aws/dynamodb/dynamodb.go index c163a7d80..3b705cfab 100644 --- a/state/aws/dynamodb/dynamodb.go +++ b/state/aws/dynamodb/dynamodb.go @@ -14,6 +14,7 @@ limitations under the License. package dynamodb import ( + "context" "crypto/rand" "encoding/binary" "encoding/json" @@ -79,7 +80,7 @@ func (d *StateStore) Features() []state.Feature { } // Get retrieves a dynamoDB item. -func (d *StateStore) Get(req *state.GetRequest) (*state.GetResponse, error) { +func (d *StateStore) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) { input := &dynamodb.GetItemInput{ ConsistentRead: aws.Bool(req.Options.Consistency == state.Strong), TableName: aws.String(d.table), @@ -90,7 +91,7 @@ func (d *StateStore) Get(req *state.GetRequest) (*state.GetResponse, error) { }, } - result, err := d.client.GetItem(input) + result, err := d.client.GetItemWithContext(ctx, input) if err != nil { return nil, err } @@ -133,13 +134,13 @@ func (d *StateStore) Get(req *state.GetRequest) (*state.GetResponse, error) { } // BulkGet performs a bulk get operations. -func (d *StateStore) BulkGet(req []state.GetRequest) (bool, []state.BulkGetResponse, error) { +func (d *StateStore) BulkGet(ctx context.Context, req []state.GetRequest) (bool, []state.BulkGetResponse, error) { // TODO: replace with dynamodb.BatchGetItem for performance return false, nil, nil } // Set saves a dynamoDB item. -func (d *StateStore) Set(req *state.SetRequest) error { +func (d *StateStore) Set(ctx context.Context, req *state.SetRequest) error { item, err := d.getItemFromReq(req) if err != nil { return err @@ -165,7 +166,7 @@ func (d *StateStore) Set(req *state.SetRequest) error { input.ConditionExpression = &condExpr } - _, err = d.client.PutItem(input) + _, err = d.client.PutItemWithContext(ctx, input) if err != nil && haveEtag { switch cErr := err.(type) { case *dynamodb.ConditionalCheckFailedException: @@ -177,11 +178,11 @@ func (d *StateStore) Set(req *state.SetRequest) error { } // BulkSet performs a bulk set operation. -func (d *StateStore) BulkSet(req []state.SetRequest) error { +func (d *StateStore) BulkSet(ctx context.Context, req []state.SetRequest) error { writeRequests := []*dynamodb.WriteRequest{} if len(req) == 1 { - return d.Set(&req[0]) + return d.Set(ctx, &req[0]) } for _, r := range req { @@ -210,7 +211,7 @@ func (d *StateStore) BulkSet(req []state.SetRequest) error { requestItems := map[string][]*dynamodb.WriteRequest{} requestItems[d.table] = writeRequests - _, e := d.client.BatchWriteItem(&dynamodb.BatchWriteItemInput{ + _, e := d.client.BatchWriteItemWithContext(ctx, &dynamodb.BatchWriteItemInput{ RequestItems: requestItems, }) @@ -218,7 +219,7 @@ func (d *StateStore) BulkSet(req []state.SetRequest) error { } // Delete performs a delete operation. -func (d *StateStore) Delete(req *state.DeleteRequest) error { +func (d *StateStore) Delete(ctx context.Context, req *state.DeleteRequest) error { input := &dynamodb.DeleteItemInput{ Key: map[string]*dynamodb.AttributeValue{ "key": { @@ -238,7 +239,7 @@ func (d *StateStore) Delete(req *state.DeleteRequest) error { input.ExpressionAttributeValues = exprAttrValues } - _, err := d.client.DeleteItem(input) + _, err := d.client.DeleteItemWithContext(ctx, input) if err != nil { switch cErr := err.(type) { case *dynamodb.ConditionalCheckFailedException: @@ -250,11 +251,11 @@ func (d *StateStore) Delete(req *state.DeleteRequest) error { } // BulkDelete performs a bulk delete operation. -func (d *StateStore) BulkDelete(req []state.DeleteRequest) error { +func (d *StateStore) BulkDelete(ctx context.Context, req []state.DeleteRequest) error { writeRequests := []*dynamodb.WriteRequest{} if len(req) == 1 { - return d.Delete(&req[0]) + return d.Delete(ctx, &req[0]) } for _, r := range req { @@ -277,7 +278,7 @@ func (d *StateStore) BulkDelete(req []state.DeleteRequest) error { requestItems := map[string][]*dynamodb.WriteRequest{} requestItems[d.table] = writeRequests - _, e := d.client.BatchWriteItem(&dynamodb.BatchWriteItemInput{ + _, e := d.client.BatchWriteItemWithContext(ctx, &dynamodb.BatchWriteItemInput{ RequestItems: requestItems, }) diff --git a/state/aws/dynamodb/dynamodb_test.go b/state/aws/dynamodb/dynamodb_test.go index 73c4d8745..2bc6fd4c7 100644 --- a/state/aws/dynamodb/dynamodb_test.go +++ b/state/aws/dynamodb/dynamodb_test.go @@ -15,12 +15,14 @@ limitations under the License. package dynamodb import ( + "context" "fmt" "strconv" "testing" "time" "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/request" "github.com/aws/aws-sdk-go/service/dynamodb" "github.com/aws/aws-sdk-go/service/dynamodb/dynamodbattribute" "github.com/aws/aws-sdk-go/service/dynamodb/dynamodbiface" @@ -31,10 +33,10 @@ import ( ) type mockedDynamoDB struct { - GetItemFn func(input *dynamodb.GetItemInput) (*dynamodb.GetItemOutput, error) - PutItemFn func(input *dynamodb.PutItemInput) (*dynamodb.PutItemOutput, error) - DeleteItemFn func(input *dynamodb.DeleteItemInput) (*dynamodb.DeleteItemOutput, error) - BatchWriteItemFn func(input *dynamodb.BatchWriteItemInput) (*dynamodb.BatchWriteItemOutput, error) + GetItemWithContextFn func(ctx context.Context, input *dynamodb.GetItemInput, op ...request.Option) (*dynamodb.GetItemOutput, error) + PutItemWithContextFn func(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (*dynamodb.PutItemOutput, error) + DeleteItemWithContextFn func(ctx context.Context, input *dynamodb.DeleteItemInput, op ...request.Option) (*dynamodb.DeleteItemOutput, error) + BatchWriteItemWithContextFn func(ctx context.Context, input *dynamodb.BatchWriteItemInput, op ...request.Option) (*dynamodb.BatchWriteItemOutput, error) dynamodbiface.DynamoDBAPI } @@ -44,20 +46,20 @@ type DynamoDBItem struct { TestAttributeName int64 `json:"testAttributeName"` } -func (m *mockedDynamoDB) GetItem(input *dynamodb.GetItemInput) (*dynamodb.GetItemOutput, error) { - return m.GetItemFn(input) +func (m *mockedDynamoDB) GetItemWithContext(ctx context.Context, input *dynamodb.GetItemInput, op ...request.Option) (*dynamodb.GetItemOutput, error) { + return m.GetItemWithContextFn(ctx, input, op...) } -func (m *mockedDynamoDB) PutItem(input *dynamodb.PutItemInput) (*dynamodb.PutItemOutput, error) { - return m.PutItemFn(input) +func (m *mockedDynamoDB) PutItemWithContext(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (*dynamodb.PutItemOutput, error) { + return m.PutItemWithContextFn(ctx, input, op...) } -func (m *mockedDynamoDB) DeleteItem(input *dynamodb.DeleteItemInput) (*dynamodb.DeleteItemOutput, error) { - return m.DeleteItemFn(input) +func (m *mockedDynamoDB) DeleteItemWithContext(ctx context.Context, input *dynamodb.DeleteItemInput, op ...request.Option) (*dynamodb.DeleteItemOutput, error) { + return m.DeleteItemWithContextFn(ctx, input, op...) } -func (m *mockedDynamoDB) BatchWriteItem(input *dynamodb.BatchWriteItemInput) (*dynamodb.BatchWriteItemOutput, error) { - return m.BatchWriteItemFn(input) +func (m *mockedDynamoDB) BatchWriteItemWithContext(ctx context.Context, input *dynamodb.BatchWriteItemInput, op ...request.Option) (*dynamodb.BatchWriteItemOutput, error) { + return m.BatchWriteItemWithContextFn(ctx, input, op...) } func TestInit(t *testing.T) { @@ -99,7 +101,7 @@ func TestGet(t *testing.T) { t.Run("Successfully retrieve item", func(t *testing.T) { ss := StateStore{ client: &mockedDynamoDB{ - GetItemFn: func(input *dynamodb.GetItemInput) (output *dynamodb.GetItemOutput, err error) { + GetItemWithContextFn: func(ctx context.Context, input *dynamodb.GetItemInput, op ...request.Option) (output *dynamodb.GetItemOutput, err error) { return &dynamodb.GetItemOutput{ Item: map[string]*dynamodb.AttributeValue{ "key": { @@ -123,7 +125,7 @@ func TestGet(t *testing.T) { Consistency: "strong", }, } - out, err := ss.Get(req) + out, err := ss.Get(context.TODO(), req) assert.Nil(t, err) assert.Equal(t, []byte("some value"), out.Data) assert.Equal(t, "1bdead4badc0ffee", *out.ETag) @@ -131,7 +133,7 @@ func TestGet(t *testing.T) { t.Run("Successfully retrieve item (with unexpired ttl)", func(t *testing.T) { ss := StateStore{ client: &mockedDynamoDB{ - GetItemFn: func(input *dynamodb.GetItemInput) (output *dynamodb.GetItemOutput, err error) { + GetItemWithContextFn: func(ctx context.Context, input *dynamodb.GetItemInput, op ...request.Option) (output *dynamodb.GetItemOutput, err error) { return &dynamodb.GetItemOutput{ Item: map[string]*dynamodb.AttributeValue{ "key": { @@ -159,7 +161,7 @@ func TestGet(t *testing.T) { Consistency: "strong", }, } - out, err := ss.Get(req) + out, err := ss.Get(context.TODO(), req) assert.Nil(t, err) assert.Equal(t, []byte("some value"), out.Data) assert.Equal(t, "1bdead4badc0ffee", *out.ETag) @@ -167,7 +169,7 @@ func TestGet(t *testing.T) { t.Run("Successfully retrieve item (with expired ttl)", func(t *testing.T) { ss := StateStore{ client: &mockedDynamoDB{ - GetItemFn: func(input *dynamodb.GetItemInput) (output *dynamodb.GetItemOutput, err error) { + GetItemWithContextFn: func(ctx context.Context, input *dynamodb.GetItemInput, op ...request.Option) (output *dynamodb.GetItemOutput, err error) { return &dynamodb.GetItemOutput{ Item: map[string]*dynamodb.AttributeValue{ "key": { @@ -195,7 +197,7 @@ func TestGet(t *testing.T) { Consistency: "strong", }, } - out, err := ss.Get(req) + out, err := ss.Get(context.TODO(), req) assert.Nil(t, err) assert.Nil(t, out.Data) assert.Nil(t, out.ETag) @@ -203,7 +205,7 @@ func TestGet(t *testing.T) { t.Run("Unsuccessfully get item", func(t *testing.T) { ss := StateStore{ client: &mockedDynamoDB{ - GetItemFn: func(input *dynamodb.GetItemInput) (output *dynamodb.GetItemOutput, err error) { + GetItemWithContextFn: func(ctx context.Context, input *dynamodb.GetItemInput, op ...request.Option) (output *dynamodb.GetItemOutput, err error) { return nil, fmt.Errorf("failed to retrieve data") }, }, @@ -215,14 +217,14 @@ func TestGet(t *testing.T) { Consistency: "strong", }, } - out, err := ss.Get(req) + out, err := ss.Get(context.TODO(), req) assert.NotNil(t, err) assert.Nil(t, out) }) t.Run("Unsuccessfully with empty response", func(t *testing.T) { ss := StateStore{ client: &mockedDynamoDB{ - GetItemFn: func(input *dynamodb.GetItemInput) (output *dynamodb.GetItemOutput, err error) { + GetItemWithContextFn: func(ctx context.Context, input *dynamodb.GetItemInput, op ...request.Option) (output *dynamodb.GetItemOutput, err error) { return &dynamodb.GetItemOutput{ Item: map[string]*dynamodb.AttributeValue{}, }, nil @@ -236,7 +238,7 @@ func TestGet(t *testing.T) { Consistency: "strong", }, } - out, err := ss.Get(req) + out, err := ss.Get(context.TODO(), req) assert.Nil(t, err) assert.Nil(t, out.Data) assert.Nil(t, out.ETag) @@ -244,7 +246,7 @@ func TestGet(t *testing.T) { t.Run("Unsuccessfully with no required key", func(t *testing.T) { ss := StateStore{ client: &mockedDynamoDB{ - GetItemFn: func(input *dynamodb.GetItemInput) (output *dynamodb.GetItemOutput, err error) { + GetItemWithContextFn: func(ctx context.Context, input *dynamodb.GetItemInput, op ...request.Option) (output *dynamodb.GetItemOutput, err error) { return &dynamodb.GetItemOutput{ Item: map[string]*dynamodb.AttributeValue{ "value2": { @@ -262,7 +264,7 @@ func TestGet(t *testing.T) { Consistency: "strong", }, } - out, err := ss.Get(req) + out, err := ss.Get(context.TODO(), req) assert.Nil(t, err) assert.Empty(t, out.Data) }) @@ -276,7 +278,7 @@ func TestSet(t *testing.T) { t.Run("Successfully set item", func(t *testing.T) { ss := StateStore{ client: &mockedDynamoDB{ - PutItemFn: func(input *dynamodb.PutItemInput) (output *dynamodb.PutItemOutput, err error) { + PutItemWithContextFn: func(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (output *dynamodb.PutItemOutput, err error) { assert.Equal(t, dynamodb.AttributeValue{ S: aws.String("key"), }, *input.Item["key"]) @@ -301,14 +303,14 @@ func TestSet(t *testing.T) { Value: "value", }, } - err := ss.Set(req) + err := ss.Set(context.TODO(), req) assert.Nil(t, err) }) t.Run("Successfully set item with matching etag", func(t *testing.T) { ss := StateStore{ client: &mockedDynamoDB{ - PutItemFn: func(input *dynamodb.PutItemInput) (output *dynamodb.PutItemOutput, err error) { + PutItemWithContextFn: func(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (output *dynamodb.PutItemOutput, err error) { assert.Equal(t, dynamodb.AttributeValue{ S: aws.String("key"), }, *input.Item["key"]) @@ -339,14 +341,14 @@ func TestSet(t *testing.T) { Value: "value", }, } - err := ss.Set(req) + err := ss.Set(context.TODO(), req) assert.Nil(t, err) }) t.Run("Unsuccessfully set item with mismatched etag", func(t *testing.T) { ss := StateStore{ client: &mockedDynamoDB{ - PutItemFn: func(input *dynamodb.PutItemInput) (output *dynamodb.PutItemOutput, err error) { + PutItemWithContextFn: func(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (output *dynamodb.PutItemOutput, err error) { assert.Equal(t, dynamodb.AttributeValue{ S: aws.String("key"), }, *input.Item["key"]) @@ -373,7 +375,7 @@ func TestSet(t *testing.T) { }, } - err := ss.Set(req) + err := ss.Set(context.TODO(), req) assert.NotNil(t, err) switch tagErr := err.(type) { case *state.ETagError: @@ -386,7 +388,7 @@ func TestSet(t *testing.T) { t.Run("Successfully set item with first-write-concurrency", func(t *testing.T) { ss := StateStore{ client: &mockedDynamoDB{ - PutItemFn: func(input *dynamodb.PutItemInput) (output *dynamodb.PutItemOutput, err error) { + PutItemWithContextFn: func(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (output *dynamodb.PutItemOutput, err error) { assert.Equal(t, dynamodb.AttributeValue{ S: aws.String("key"), }, *input.Item["key"]) @@ -415,14 +417,14 @@ func TestSet(t *testing.T) { Concurrency: state.FirstWrite, }, } - err := ss.Set(req) + err := ss.Set(context.TODO(), req) assert.Nil(t, err) }) t.Run("Unsuccessfully set item with first-write-concurrency", func(t *testing.T) { ss := StateStore{ client: &mockedDynamoDB{ - PutItemFn: func(input *dynamodb.PutItemInput) (output *dynamodb.PutItemOutput, err error) { + PutItemWithContextFn: func(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (output *dynamodb.PutItemOutput, err error) { assert.Equal(t, dynamodb.AttributeValue{ S: aws.String("key"), }, *input.Item["key"]) @@ -446,7 +448,7 @@ func TestSet(t *testing.T) { Concurrency: state.FirstWrite, }, } - err := ss.Set(req) + err := ss.Set(context.TODO(), req) assert.NotNil(t, err) switch err.(type) { case *state.ETagError: @@ -458,7 +460,7 @@ func TestSet(t *testing.T) { t.Run("Successfully set item with ttl = -1", func(t *testing.T) { ss := StateStore{ client: &mockedDynamoDB{ - PutItemFn: func(input *dynamodb.PutItemInput) (output *dynamodb.PutItemOutput, err error) { + PutItemWithContextFn: func(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (output *dynamodb.PutItemOutput, err error) { assert.Equal(t, len(input.Item), 4) result := DynamoDBItem{} dynamodbattribute.UnmarshalMap(input.Item, &result) @@ -487,13 +489,13 @@ func TestSet(t *testing.T) { "ttlInSeconds": "-1", }, } - err := ss.Set(req) + err := ss.Set(context.TODO(), req) assert.Nil(t, err) }) t.Run("Successfully set item with 'correct' ttl", func(t *testing.T) { ss := StateStore{ client: &mockedDynamoDB{ - PutItemFn: func(input *dynamodb.PutItemInput) (output *dynamodb.PutItemOutput, err error) { + PutItemWithContextFn: func(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (output *dynamodb.PutItemOutput, err error) { assert.Equal(t, len(input.Item), 4) result := DynamoDBItem{} dynamodbattribute.UnmarshalMap(input.Item, &result) @@ -522,14 +524,14 @@ func TestSet(t *testing.T) { "ttlInSeconds": "180", }, } - err := ss.Set(req) + err := ss.Set(context.TODO(), req) assert.Nil(t, err) }) t.Run("Unsuccessfully set item", func(t *testing.T) { ss := StateStore{ client: &mockedDynamoDB{ - PutItemFn: func(input *dynamodb.PutItemInput) (output *dynamodb.PutItemOutput, err error) { + PutItemWithContextFn: func(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (output *dynamodb.PutItemOutput, err error) { return nil, fmt.Errorf("unable to put item") }, }, @@ -540,13 +542,13 @@ func TestSet(t *testing.T) { Value: "value", }, } - err := ss.Set(req) + err := ss.Set(context.TODO(), req) assert.NotNil(t, err) }) t.Run("Successfully set item with correct ttl but without component metadata", func(t *testing.T) { ss := StateStore{ client: &mockedDynamoDB{ - PutItemFn: func(input *dynamodb.PutItemInput) (output *dynamodb.PutItemOutput, err error) { + PutItemWithContextFn: func(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (output *dynamodb.PutItemOutput, err error) { assert.Equal(t, dynamodb.AttributeValue{ S: aws.String("someKey"), }, *input.Item["key"]) @@ -574,13 +576,13 @@ func TestSet(t *testing.T) { "ttlInSeconds": "180", }, } - err := ss.Set(req) + err := ss.Set(context.TODO(), req) assert.Nil(t, err) }) t.Run("Unsuccessfully set item with ttl (invalid value)", func(t *testing.T) { ss := StateStore{ client: &mockedDynamoDB{ - PutItemFn: func(input *dynamodb.PutItemInput) (output *dynamodb.PutItemOutput, err error) { + PutItemWithContextFn: func(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (output *dynamodb.PutItemOutput, err error) { assert.Equal(t, map[string]*dynamodb.AttributeValue{ "key": { S: aws.String("somekey"), @@ -613,7 +615,7 @@ func TestSet(t *testing.T) { "ttlInSeconds": "invalidvalue", }, } - err := ss.Set(req) + err := ss.Set(context.TODO(), req) assert.NotNil(t, err) assert.Equal(t, "dynamodb error: failed to parse ttlInSeconds: strconv.ParseInt: parsing \"invalidvalue\": invalid syntax", err.Error()) }) @@ -628,7 +630,7 @@ func TestBulkSet(t *testing.T) { tableName := "table_name" ss := StateStore{ client: &mockedDynamoDB{ - BatchWriteItemFn: func(input *dynamodb.BatchWriteItemInput) (output *dynamodb.BatchWriteItemOutput, err error) { + BatchWriteItemWithContextFn: func(ctx context.Context, input *dynamodb.BatchWriteItemInput, op ...request.Option) (output *dynamodb.BatchWriteItemOutput, err error) { expected := map[string][]*dynamodb.WriteRequest{} expected[tableName] = []*dynamodb.WriteRequest{ { @@ -688,14 +690,14 @@ func TestBulkSet(t *testing.T) { }, }, } - err := ss.BulkSet(req) + err := ss.BulkSet(context.TODO(), req) assert.Nil(t, err) }) t.Run("Successfully set items with ttl = -1", func(t *testing.T) { tableName := "table_name" ss := StateStore{ client: &mockedDynamoDB{ - BatchWriteItemFn: func(input *dynamodb.BatchWriteItemInput) (output *dynamodb.BatchWriteItemOutput, err error) { + BatchWriteItemWithContextFn: func(ctx context.Context, input *dynamodb.BatchWriteItemInput, op ...request.Option) (output *dynamodb.BatchWriteItemOutput, err error) { expected := map[string][]*dynamodb.WriteRequest{} expected[tableName] = []*dynamodb.WriteRequest{ { @@ -761,14 +763,14 @@ func TestBulkSet(t *testing.T) { }, }, } - err := ss.BulkSet(req) + err := ss.BulkSet(context.TODO(), req) assert.Nil(t, err) }) t.Run("Successfully set items with ttl", func(t *testing.T) { tableName := "table_name" ss := StateStore{ client: &mockedDynamoDB{ - BatchWriteItemFn: func(input *dynamodb.BatchWriteItemInput) (output *dynamodb.BatchWriteItemOutput, err error) { + BatchWriteItemWithContextFn: func(ctx context.Context, input *dynamodb.BatchWriteItemInput, op ...request.Option) (output *dynamodb.BatchWriteItemOutput, err error) { expected := map[string][]*dynamodb.WriteRequest{} // This might fail occasionally due to timestamp precision. timestamp := time.Now().Unix() + 90 @@ -836,13 +838,13 @@ func TestBulkSet(t *testing.T) { }, }, } - err := ss.BulkSet(req) + err := ss.BulkSet(context.TODO(), req) assert.Nil(t, err) }) t.Run("Unsuccessfully set items", func(t *testing.T) { ss := StateStore{ client: &mockedDynamoDB{ - BatchWriteItemFn: func(input *dynamodb.BatchWriteItemInput) (output *dynamodb.BatchWriteItemOutput, err error) { + BatchWriteItemWithContextFn: func(ctx context.Context, input *dynamodb.BatchWriteItemInput, op ...request.Option) (output *dynamodb.BatchWriteItemOutput, err error) { return nil, fmt.Errorf("unable to bulk write items") }, }, @@ -861,7 +863,7 @@ func TestBulkSet(t *testing.T) { }, }, } - err := ss.BulkSet(req) + err := ss.BulkSet(context.TODO(), req) assert.NotNil(t, err) }) } @@ -874,7 +876,7 @@ func TestDelete(t *testing.T) { ss := StateStore{ client: &mockedDynamoDB{ - DeleteItemFn: func(input *dynamodb.DeleteItemInput) (output *dynamodb.DeleteItemOutput, err error) { + DeleteItemWithContextFn: func(ctx context.Context, input *dynamodb.DeleteItemInput, op ...request.Option) (output *dynamodb.DeleteItemOutput, err error) { assert.Equal(t, map[string]*dynamodb.AttributeValue{ "key": { S: aws.String(req.Key), @@ -885,7 +887,7 @@ func TestDelete(t *testing.T) { }, }, } - err := ss.Delete(req) + err := ss.Delete(context.TODO(), req) assert.Nil(t, err) }) @@ -898,7 +900,7 @@ func TestDelete(t *testing.T) { ss := StateStore{ client: &mockedDynamoDB{ - DeleteItemFn: func(input *dynamodb.DeleteItemInput) (output *dynamodb.DeleteItemOutput, err error) { + DeleteItemWithContextFn: func(ctx context.Context, input *dynamodb.DeleteItemInput, op ...request.Option) (output *dynamodb.DeleteItemOutput, err error) { assert.Equal(t, map[string]*dynamodb.AttributeValue{ "key": { S: aws.String(req.Key), @@ -913,7 +915,7 @@ func TestDelete(t *testing.T) { }, }, } - err := ss.Delete(req) + err := ss.Delete(context.TODO(), req) assert.Nil(t, err) }) @@ -926,7 +928,7 @@ func TestDelete(t *testing.T) { ss := StateStore{ client: &mockedDynamoDB{ - DeleteItemFn: func(input *dynamodb.DeleteItemInput) (output *dynamodb.DeleteItemOutput, err error) { + DeleteItemWithContextFn: func(ctx context.Context, input *dynamodb.DeleteItemInput, op ...request.Option) (output *dynamodb.DeleteItemOutput, err error) { assert.Equal(t, map[string]*dynamodb.AttributeValue{ "key": { S: aws.String(req.Key), @@ -942,7 +944,7 @@ func TestDelete(t *testing.T) { }, }, } - err := ss.Delete(req) + err := ss.Delete(context.TODO(), req) assert.NotNil(t, err) switch tagErr := err.(type) { case *state.ETagError: @@ -955,7 +957,7 @@ func TestDelete(t *testing.T) { t.Run("Unsuccessfully delete item", func(t *testing.T) { ss := StateStore{ client: &mockedDynamoDB{ - DeleteItemFn: func(input *dynamodb.DeleteItemInput) (output *dynamodb.DeleteItemOutput, err error) { + DeleteItemWithContextFn: func(ctx context.Context, input *dynamodb.DeleteItemInput, op ...request.Option) (output *dynamodb.DeleteItemOutput, err error) { return nil, fmt.Errorf("unable to delete item") }, }, @@ -963,7 +965,7 @@ func TestDelete(t *testing.T) { req := &state.DeleteRequest{ Key: "key", } - err := ss.Delete(req) + err := ss.Delete(context.TODO(), req) assert.NotNil(t, err) }) } @@ -973,7 +975,7 @@ func TestBulkDelete(t *testing.T) { tableName := "table_name" ss := StateStore{ client: &mockedDynamoDB{ - BatchWriteItemFn: func(input *dynamodb.BatchWriteItemInput) (output *dynamodb.BatchWriteItemOutput, err error) { + BatchWriteItemWithContextFn: func(ctx context.Context, input *dynamodb.BatchWriteItemInput, op ...request.Option) (output *dynamodb.BatchWriteItemOutput, err error) { expected := map[string][]*dynamodb.WriteRequest{} expected[tableName] = []*dynamodb.WriteRequest{ { @@ -1012,13 +1014,13 @@ func TestBulkDelete(t *testing.T) { Key: "key2", }, } - err := ss.BulkDelete(req) + err := ss.BulkDelete(context.TODO(), req) assert.Nil(t, err) }) t.Run("Unsuccessfully delete items", func(t *testing.T) { ss := StateStore{ client: &mockedDynamoDB{ - BatchWriteItemFn: func(input *dynamodb.BatchWriteItemInput) (output *dynamodb.BatchWriteItemOutput, err error) { + BatchWriteItemWithContextFn: func(ctx context.Context, input *dynamodb.BatchWriteItemInput, op ...request.Option) (output *dynamodb.BatchWriteItemOutput, err error) { return nil, fmt.Errorf("unable to bulk write items") }, }, @@ -1031,7 +1033,7 @@ func TestBulkDelete(t *testing.T) { Key: "key", }, } - err := ss.BulkDelete(req) + err := ss.BulkDelete(context.TODO(), req) assert.NotNil(t, err) }) } diff --git a/state/azure/blobstorage/blobstorage.go b/state/azure/blobstorage/blobstorage.go index cb70b95ea..024c8e1fe 100644 --- a/state/azure/blobstorage/blobstorage.go +++ b/state/azure/blobstorage/blobstorage.go @@ -130,21 +130,21 @@ func (r *StateStore) Features() []state.Feature { } // Delete the state. -func (r *StateStore) Delete(req *state.DeleteRequest) error { +func (r *StateStore) Delete(ctx context.Context, req *state.DeleteRequest) error { r.logger.Debugf("delete %s", req.Key) - return r.deleteFile(context.Background(), req) + return r.deleteFile(ctx, req) } // Get the state. -func (r *StateStore) Get(req *state.GetRequest) (*state.GetResponse, error) { +func (r *StateStore) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) { r.logger.Debugf("get %s", req.Key) - return r.readFile(context.Background(), req) + return r.readFile(ctx, req) } // Set the state. -func (r *StateStore) Set(req *state.SetRequest) error { +func (r *StateStore) Set(ctx context.Context, req *state.SetRequest) error { r.logger.Debugf("saving %s", req.Key) - return r.writeFile(context.Background(), req) + return r.writeFile(ctx, req) } func (r *StateStore) Ping() error { diff --git a/state/azure/cosmosdb/cosmosdb.go b/state/azure/cosmosdb/cosmosdb.go index 3caf4b04e..4369a19cb 100644 --- a/state/azure/cosmosdb/cosmosdb.go +++ b/state/azure/cosmosdb/cosmosdb.go @@ -202,7 +202,7 @@ func (c *StateStore) Features() []state.Feature { } // Get retrieves a CosmosDB item. -func (c *StateStore) Get(req *state.GetRequest) (*state.GetResponse, error) { +func (c *StateStore) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) { partitionKey := populatePartitionMetadata(req.Key, req.Metadata) options := azcosmos.ItemOptions{} @@ -212,9 +212,7 @@ func (c *StateStore) Get(req *state.GetRequest) (*state.GetResponse, error) { options.ConsistencyLevel = azcosmos.ConsistencyLevelEventual.ToPtr() } - ctx, cancel := context.WithTimeout(context.Background(), defaultTimeout) readItem, err := c.client.ReadItem(ctx, azcosmos.NewPartitionKeyString(partitionKey), req.Key, &options) - cancel() if err != nil { var responseErr *azcore.ResponseError if errors.As(err, &responseErr) && responseErr.ErrorCode == "NotFound" { @@ -263,7 +261,7 @@ func (c *StateStore) Get(req *state.GetRequest) (*state.GetResponse, error) { } // Set saves a CosmosDB item. -func (c *StateStore) Set(req *state.SetRequest) error { +func (c *StateStore) Set(ctx context.Context, req *state.SetRequest) error { err := state.CheckRequestOptions(req.Options) if err != nil { return err @@ -301,10 +299,8 @@ func (c *StateStore) Set(req *state.SetRequest) error { return err } - ctx, cancel := context.WithTimeout(context.Background(), defaultTimeout) pk := azcosmos.NewPartitionKeyString(partitionKey) _, err = c.client.UpsertItem(ctx, pk, marsh, &options) - cancel() if err != nil { return err } @@ -312,7 +308,7 @@ func (c *StateStore) Set(req *state.SetRequest) error { } // Delete performs a delete operation. -func (c *StateStore) Delete(req *state.DeleteRequest) error { +func (c *StateStore) Delete(ctx context.Context, req *state.DeleteRequest) error { err := state.CheckRequestOptions(req.Options) if err != nil { return err @@ -330,10 +326,8 @@ func (c *StateStore) Delete(req *state.DeleteRequest) error { options.ConsistencyLevel = azcosmos.ConsistencyLevelEventual.ToPtr() } - ctx, cancel := context.WithTimeout(context.Background(), defaultTimeout) pk := azcosmos.NewPartitionKeyString(partitionKey) _, err = c.client.DeleteItem(ctx, pk, req.Key, &options) - cancel() if err != nil && !isNotFoundError(err) { c.logger.Debugf("Error from cosmos.DeleteDocument e=%e, e.Error=%s", err, err.Error()) if req.ETag != nil && *req.ETag != "" { @@ -346,7 +340,7 @@ func (c *StateStore) Delete(req *state.DeleteRequest) error { } // Multi performs a transactional operation. succeeds only if all operations succeed, and fails if one or more operations fail. -func (c *StateStore) Multi(request *state.TransactionalStateRequest) (err error) { +func (c *StateStore) Multi(ctx context.Context, request *state.TransactionalStateRequest) (err error) { if len(request.Operations) == 0 { c.logger.Debugf("No Operations Provided") return nil @@ -413,9 +407,7 @@ func (c *StateStore) Multi(request *state.TransactionalStateRequest) (err error) c.logger.Debugf("#operations=%d,partitionkey=%s", numOperations, partitionKey) - ctx, cancel := context.WithTimeout(context.Background(), defaultTimeout) batchResponse, err := c.client.ExecuteTransactionalBatch(ctx, batch, nil) - cancel() if err != nil { return err } @@ -440,7 +432,7 @@ func (c *StateStore) Multi(request *state.TransactionalStateRequest) (err error) return nil } -func (c *StateStore) Query(req *state.QueryRequest) (*state.QueryResponse, error) { +func (c *StateStore) Query(ctx context.Context, req *state.QueryRequest) (*state.QueryResponse, error) { q := &Query{} qbuilder := query.NewQueryBuilder(q) @@ -448,7 +440,7 @@ func (c *StateStore) Query(req *state.QueryRequest) (*state.QueryResponse, error return &state.QueryResponse{}, err } - data, token, err := q.execute(c.client) + data, token, err := q.execute(ctx, c.client) if err != nil { return nil, err } diff --git a/state/azure/cosmosdb/cosmosdb_query.go b/state/azure/cosmosdb/cosmosdb_query.go index 7506f0659..da4da4c13 100644 --- a/state/azure/cosmosdb/cosmosdb_query.go +++ b/state/azure/cosmosdb/cosmosdb_query.go @@ -144,7 +144,7 @@ func (q *Query) setNextParameter(val string) string { return pname } -func (q *Query) execute(client *azcosmos.ContainerClient) ([]state.QueryItem, string, error) { +func (q *Query) execute(ctx context.Context, client *azcosmos.ContainerClient) ([]state.QueryItem, string, error) { opts := &azcosmos.QueryOptions{} resultLimit := q.limit @@ -160,9 +160,7 @@ func (q *Query) execute(client *azcosmos.ContainerClient) ([]state.QueryItem, st token := "" for queryPager.More() { - ctx, cancel := context.WithTimeout(context.Background(), defaultTimeout) queryResponse, innerErr := queryPager.NextPage(ctx) - cancel() if innerErr != nil { return nil, "", innerErr } diff --git a/state/azure/tablestorage/tablestorage.go b/state/azure/tablestorage/tablestorage.go index 89dede5fc..33056120d 100644 --- a/state/azure/tablestorage/tablestorage.go +++ b/state/azure/tablestorage/tablestorage.go @@ -163,10 +163,10 @@ func (r *StateStore) Features() []state.Feature { return r.features } -func (r *StateStore) Delete(req *state.DeleteRequest) error { +func (r *StateStore) Delete(ctx context.Context, req *state.DeleteRequest) error { r.logger.Debugf("delete %s", req.Key) - err := r.deleteRow(req) + err := r.deleteRow(ctx, req) if err != nil { if req.ETag != nil { return state.NewETagError(state.ETagMismatch, err) @@ -179,12 +179,10 @@ func (r *StateStore) Delete(req *state.DeleteRequest) error { return err } -func (r *StateStore) Get(req *state.GetRequest) (*state.GetResponse, error) { +func (r *StateStore) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) { r.logger.Debugf("fetching %s", req.Key) pk, rk := getPartitionAndRowKey(req.Key, r.cosmosDBMode) - getContext, cancel := context.WithTimeout(context.Background(), timeout) - defer cancel() - resp, err := r.client.GetEntity(getContext, pk, rk, nil) + resp, err := r.client.GetEntity(ctx, pk, rk, nil) if err != nil { if isNotFoundError(err) { return &state.GetResponse{}, nil @@ -200,10 +198,10 @@ func (r *StateStore) Get(req *state.GetRequest) (*state.GetResponse, error) { }, err } -func (r *StateStore) Set(req *state.SetRequest) error { +func (r *StateStore) Set(ctx context.Context, req *state.SetRequest) error { r.logger.Debugf("saving %s", req.Key) - err := r.writeRow(req) + err := r.writeRow(ctx, req) return err } @@ -254,30 +252,26 @@ func getTablesMetadata(metadata map[string]string) (*tablesMetadata, error) { return &meta, nil } -func (r *StateStore) writeRow(req *state.SetRequest) error { +func (r *StateStore) writeRow(ctx context.Context, req *state.SetRequest) error { marshalledEntity, err := r.marshal(req) if err != nil { return err } - writeContext, cancel := context.WithTimeout(context.Background(), timeout) - defer cancel() // InsertOrReplace does not support ETag concurrency, therefore we will use Insert to check for key existence // and then use Update to update the key if it exists with the specified ETag - _, err = r.client.AddEntity(writeContext, marshalledEntity, nil) + _, err = r.client.AddEntity(ctx, marshalledEntity, nil) if err != nil { // If Insert failed because item already exists, try to Update instead per Upsert semantics if isEntityAlreadyExistsError(err) { - updateContext, cancel := context.WithTimeout(context.Background(), timeout) - defer cancel() // Always Update using the etag when provided even if Concurrency != FirstWrite. // Today the presence of etag takes precedence over Concurrency. // In the future #2739 will impose a breaking change which must disallow the use of etag when not using FirstWrite. if req.ETag != nil && *req.ETag != "" { etag := azcore.ETag(*req.ETag) - _, uerr := r.client.UpdateEntity(updateContext, marshalledEntity, &aztables.UpdateEntityOptions{ + _, uerr := r.client.UpdateEntity(ctx, marshalledEntity, &aztables.UpdateEntityOptions{ IfMatch: &etag, UpdateMode: aztables.UpdateModeReplace, }) @@ -295,7 +289,7 @@ func (r *StateStore) writeRow(req *state.SetRequest) error { return state.NewETagError(state.ETagMismatch, errors.New("update with Concurrency.FirstWrite without ETag")) } else { // Finally, last write semantics without ETag should always perform a force update. - _, uerr := r.client.UpdateEntity(updateContext, marshalledEntity, &aztables.UpdateEntityOptions{ + _, uerr := r.client.UpdateEntity(ctx, marshalledEntity, &aztables.UpdateEntityOptions{ IfMatch: nil, // this is the same as "*" matching all ETags UpdateMode: aztables.UpdateModeReplace, }) @@ -336,19 +330,16 @@ func isTableAlreadyExistsError(err error) bool { return false } -func (r *StateStore) deleteRow(req *state.DeleteRequest) error { +func (r *StateStore) deleteRow(ctx context.Context, req *state.DeleteRequest) error { pk, rk := getPartitionAndRowKey(req.Key, r.cosmosDBMode) - deleteContext, cancel := context.WithTimeout(context.Background(), timeout) - defer cancel() - if req.ETag != nil { azcoreETag := azcore.ETag(*req.ETag) - _, err := r.client.DeleteEntity(deleteContext, pk, rk, &aztables.DeleteEntityOptions{IfMatch: &azcoreETag}) + _, err := r.client.DeleteEntity(ctx, pk, rk, &aztables.DeleteEntityOptions{IfMatch: &azcoreETag}) return err } all := azcore.ETagAny - _, err := r.client.DeleteEntity(deleteContext, pk, rk, &aztables.DeleteEntityOptions{IfMatch: &all}) + _, err := r.client.DeleteEntity(ctx, pk, rk, &aztables.DeleteEntityOptions{IfMatch: &all}) return err } diff --git a/state/cassandra/cassandra.go b/state/cassandra/cassandra.go index 0fe9a0f2c..4a8c49c98 100644 --- a/state/cassandra/cassandra.go +++ b/state/cassandra/cassandra.go @@ -14,6 +14,7 @@ limitations under the License. package cassandra import ( + "context" "errors" "fmt" "strconv" @@ -230,12 +231,12 @@ func getCassandraMetadata(metadata state.Metadata) (*cassandraMetadata, error) { } // Delete performs a delete operation. -func (c *Cassandra) Delete(req *state.DeleteRequest) error { - return c.session.Query(fmt.Sprintf("DELETE FROM %s WHERE key = ?", c.table), req.Key).Exec() +func (c *Cassandra) Delete(ctx context.Context, req *state.DeleteRequest) error { + return c.session.Query(fmt.Sprintf("DELETE FROM %s WHERE key = ?", c.table), req.Key).WithContext(ctx).Exec() } // Get retrieves state from cassandra with a key. -func (c *Cassandra) Get(req *state.GetRequest) (*state.GetResponse, error) { +func (c *Cassandra) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) { session := c.session if req.Options.Consistency == state.Strong { @@ -254,7 +255,7 @@ func (c *Cassandra) Get(req *state.GetRequest) (*state.GetResponse, error) { session = sess } - results, err := session.Query(fmt.Sprintf("SELECT value FROM %s WHERE key = ?", c.table), req.Key).Iter().SliceMap() + results, err := session.Query(fmt.Sprintf("SELECT value FROM %s WHERE key = ?", c.table), req.Key).WithContext(ctx).Iter().SliceMap() if err != nil { return nil, err } @@ -269,7 +270,7 @@ func (c *Cassandra) Get(req *state.GetRequest) (*state.GetResponse, error) { } // Set saves state into cassandra. -func (c *Cassandra) Set(req *state.SetRequest) error { +func (c *Cassandra) Set(ctx context.Context, req *state.SetRequest) error { var bt []byte b, ok := req.Value.([]byte) if ok { @@ -302,10 +303,10 @@ func (c *Cassandra) Set(req *state.SetRequest) error { } if ttl != nil { - return session.Query(fmt.Sprintf("INSERT INTO %s (key, value) VALUES (?, ?) USING TTL ?", c.table), req.Key, bt, *ttl).Exec() + return session.Query(fmt.Sprintf("INSERT INTO %s (key, value) VALUES (?, ?) USING TTL ?", c.table), req.Key, bt, *ttl).WithContext(ctx).Exec() } - return session.Query(fmt.Sprintf("INSERT INTO %s (key, value) VALUES (?, ?)", c.table), req.Key, bt).Exec() + return session.Query(fmt.Sprintf("INSERT INTO %s (key, value) VALUES (?, ?)", c.table), req.Key, bt).WithContext(ctx).Exec() } func (c *Cassandra) createSession(consistency gocql.Consistency) (*gocql.Session, error) { diff --git a/state/cockroachdb/cockroachdb.go b/state/cockroachdb/cockroachdb.go index f0c46c6b2..e355dccfa 100644 --- a/state/cockroachdb/cockroachdb.go +++ b/state/cockroachdb/cockroachdb.go @@ -14,6 +14,8 @@ limitations under the License. package cockroachdb import ( + "context" + "github.com/dapr/components-contrib/state" "github.com/dapr/kit/logger" ) @@ -53,18 +55,18 @@ func (c *CockroachDB) Features() []state.Feature { } // Delete removes an entity from the store. -func (c *CockroachDB) Delete(req *state.DeleteRequest) error { - return c.dbaccess.Delete(req) +func (c *CockroachDB) Delete(ctx context.Context, req *state.DeleteRequest) error { + return c.dbaccess.Delete(ctx, req) } // Get returns an entity from store. -func (c *CockroachDB) Get(req *state.GetRequest) (*state.GetResponse, error) { - return c.dbaccess.Get(req) +func (c *CockroachDB) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) { + return c.dbaccess.Get(ctx, req) } // Set adds/updates an entity on store. -func (c *CockroachDB) Set(req *state.SetRequest) error { - return c.dbaccess.Set(req) +func (c *CockroachDB) Set(ctx context.Context, req *state.SetRequest) error { + return c.dbaccess.Set(ctx, req) } // Ping checks if database is available. @@ -73,29 +75,29 @@ func (c *CockroachDB) Ping() error { } // BulkDelete removes multiple entries from the store. -func (c *CockroachDB) BulkDelete(req []state.DeleteRequest) error { - return c.dbaccess.BulkDelete(req) +func (c *CockroachDB) BulkDelete(ctx context.Context, req []state.DeleteRequest) error { + return c.dbaccess.BulkDelete(ctx, req) } // BulkGet performs a bulks get operations. -func (c *CockroachDB) BulkGet(req []state.GetRequest) (bool, []state.BulkGetResponse, error) { +func (c *CockroachDB) BulkGet(ctx context.Context, req []state.GetRequest) (bool, []state.BulkGetResponse, error) { // TODO: replace with ExecuteMulti for performance. return false, nil, nil } // BulkSet adds/updates multiple entities on store. -func (c *CockroachDB) BulkSet(req []state.SetRequest) error { - return c.dbaccess.BulkSet(req) +func (c *CockroachDB) BulkSet(ctx context.Context, req []state.SetRequest) error { + return c.dbaccess.BulkSet(ctx, req) } // Multi handles multiple transactions. Implements TransactionalStore. -func (c *CockroachDB) Multi(request *state.TransactionalStateRequest) error { - return c.dbaccess.ExecuteMulti(request) +func (c *CockroachDB) Multi(ctx context.Context, request *state.TransactionalStateRequest) error { + return c.dbaccess.ExecuteMulti(ctx, request) } // Query executes a query against store. -func (c *CockroachDB) Query(req *state.QueryRequest) (*state.QueryResponse, error) { - return c.dbaccess.Query(req) +func (c *CockroachDB) Query(ctx context.Context, req *state.QueryRequest) (*state.QueryResponse, error) { + return c.dbaccess.Query(ctx, req) } // Close implements io.Closer. diff --git a/state/cockroachdb/cockroachdb_access.go b/state/cockroachdb/cockroachdb_access.go index 9d85db381..c71222df8 100644 --- a/state/cockroachdb/cockroachdb_access.go +++ b/state/cockroachdb/cockroachdb_access.go @@ -14,6 +14,7 @@ limitations under the License. package cockroachdb import ( + "context" "database/sql" "encoding/base64" "encoding/json" @@ -96,12 +97,12 @@ func (p *cockroachDBAccess) Init(metadata state.Metadata) error { } // Set makes an insert or update to the database. -func (p *cockroachDBAccess) Set(req *state.SetRequest) error { - return state.SetWithOptions(p.setValue, req) +func (p *cockroachDBAccess) Set(ctx context.Context, req *state.SetRequest) error { + return state.SetWithOptions(ctx, p.setValue, req) } // setValue is an internal implementation of set to enable passing the logic to state.SetWithRetries as a func. -func (p *cockroachDBAccess) setValue(req *state.SetRequest) error { +func (p *cockroachDBAccess) setValue(ctx context.Context, req *state.SetRequest) error { p.logger.Debug("Setting state value in CockroachDB") value, isBinary, err := validateAndReturnValue(req) @@ -114,7 +115,7 @@ func (p *cockroachDBAccess) setValue(req *state.SetRequest) error { // Sprintf is required for table name because sql.DB does not substitute parameters for table names. // Other parameters use sql.DB parameter substitution. if req.ETag == nil { - result, err = p.db.Exec(fmt.Sprintf( + result, err = p.db.ExecContext(ctx, fmt.Sprintf( `INSERT INTO %s (key, value, isbinary, etag) VALUES ($1, $2, $3, 1) ON CONFLICT (key) DO UPDATE SET value = $2, isbinary = $3, updatedate = NOW(), etag = EXCLUDED.etag + 1;`, tableName), req.Key, value, isBinary) @@ -127,7 +128,7 @@ func (p *cockroachDBAccess) setValue(req *state.SetRequest) error { etag := uint32(etag64) // When an etag is provided do an update - no insert. - result, err = p.db.Exec(fmt.Sprintf( + result, err = p.db.ExecContext(ctx, fmt.Sprintf( `UPDATE %s SET value = $1, isbinary = $2, updatedate = NOW(), etag = etag + 1 WHERE key = $3 AND etag = $4;`, tableName), value, isBinary, req.Key, etag) @@ -149,7 +150,7 @@ func (p *cockroachDBAccess) setValue(req *state.SetRequest) error { return nil } -func (p *cockroachDBAccess) BulkSet(req []state.SetRequest) error { +func (p *cockroachDBAccess) BulkSet(ctx context.Context, req []state.SetRequest) error { p.logger.Debug("Executing BulkSet request") tx, err := p.db.Begin() if err != nil { @@ -159,7 +160,7 @@ func (p *cockroachDBAccess) BulkSet(req []state.SetRequest) error { if len(req) > 0 { for _, s := range req { sa := s // Fix for gosec G601: Implicit memory aliasing in for loop. - err = p.Set(&sa) + err = p.Set(ctx, &sa) if err != nil { tx.Rollback() @@ -174,7 +175,7 @@ func (p *cockroachDBAccess) BulkSet(req []state.SetRequest) error { } // Get returns data from the database. If data does not exist for the key an empty state.GetResponse will be returned. -func (p *cockroachDBAccess) Get(req *state.GetRequest) (*state.GetResponse, error) { +func (p *cockroachDBAccess) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) { p.logger.Debug("Getting state value from CockroachDB") if req.Key == "" { return nil, fmt.Errorf("missing key in get operation") @@ -183,7 +184,7 @@ func (p *cockroachDBAccess) Get(req *state.GetRequest) (*state.GetResponse, erro var value string var isBinary bool var etag int - err := p.db.QueryRow(fmt.Sprintf("SELECT value, isbinary, etag FROM %s WHERE key = $1", tableName), req.Key).Scan(&value, &isBinary, &etag) + err := p.db.QueryRowContext(ctx, fmt.Sprintf("SELECT value, isbinary, etag FROM %s WHERE key = $1", tableName), req.Key).Scan(&value, &isBinary, &etag) if err != nil { // If no rows exist, return an empty response, otherwise return the error. if errors.Is(err, sql.ErrNoRows) { @@ -222,12 +223,12 @@ func (p *cockroachDBAccess) Get(req *state.GetRequest) (*state.GetResponse, erro } // Delete removes an item from the state store. -func (p *cockroachDBAccess) Delete(req *state.DeleteRequest) error { - return state.DeleteWithOptions(p.deleteValue, req) +func (p *cockroachDBAccess) Delete(ctx context.Context, req *state.DeleteRequest) error { + return state.DeleteWithOptions(ctx, p.deleteValue, req) } // deleteValue is an internal implementation of delete to enable passing the logic to state.DeleteWithRetries as a func. -func (p *cockroachDBAccess) deleteValue(req *state.DeleteRequest) error { +func (p *cockroachDBAccess) deleteValue(ctx context.Context, req *state.DeleteRequest) error { p.logger.Debug("Deleting state value from CockroachDB") if req.Key == "" { return fmt.Errorf("missing key in delete operation") @@ -237,7 +238,7 @@ func (p *cockroachDBAccess) deleteValue(req *state.DeleteRequest) error { var err error if req.ETag == nil { - result, err = p.db.Exec("DELETE FROM state WHERE key = $1", req.Key) + result, err = p.db.ExecContext(ctx, "DELETE FROM state WHERE key = $1", req.Key) } else { var etag64 uint64 etag64, err = strconv.ParseUint(*req.ETag, 10, 32) @@ -246,7 +247,7 @@ func (p *cockroachDBAccess) deleteValue(req *state.DeleteRequest) error { } etag := uint32(etag64) - result, err = p.db.Exec("DELETE FROM state WHERE key = $1 and etag = $2", req.Key, etag) + result, err = p.db.ExecContext(ctx, "DELETE FROM state WHERE key = $1 and etag = $2", req.Key, etag) } if err != nil { @@ -265,7 +266,7 @@ func (p *cockroachDBAccess) deleteValue(req *state.DeleteRequest) error { return nil } -func (p *cockroachDBAccess) BulkDelete(req []state.DeleteRequest) error { +func (p *cockroachDBAccess) BulkDelete(ctx context.Context, req []state.DeleteRequest) error { p.logger.Debug("Executing BulkDelete request") tx, err := p.db.Begin() if err != nil { @@ -275,7 +276,7 @@ func (p *cockroachDBAccess) BulkDelete(req []state.DeleteRequest) error { if len(req) > 0 { for _, d := range req { da := d // Fix for gosec G601: Implicit memory aliasing in for loop. - err = p.Delete(&da) + err = p.Delete(ctx, &da) if err != nil { tx.Rollback() @@ -289,7 +290,7 @@ func (p *cockroachDBAccess) BulkDelete(req []state.DeleteRequest) error { return err } -func (p *cockroachDBAccess) ExecuteMulti(request *state.TransactionalStateRequest) error { +func (p *cockroachDBAccess) ExecuteMulti(ctx context.Context, request *state.TransactionalStateRequest) error { p.logger.Debug("Executing PostgreSQL transaction") tx, err := p.db.Begin() @@ -308,7 +309,7 @@ func (p *cockroachDBAccess) ExecuteMulti(request *state.TransactionalStateReques return err } - err = p.Set(&setReq) + err = p.Set(ctx, &setReq) if err != nil { tx.Rollback() return err @@ -323,7 +324,7 @@ func (p *cockroachDBAccess) ExecuteMulti(request *state.TransactionalStateReques return err } - err = p.Delete(&delReq) + err = p.Delete(ctx, &delReq) if err != nil { tx.Rollback() return err @@ -341,7 +342,7 @@ func (p *cockroachDBAccess) ExecuteMulti(request *state.TransactionalStateReques } // Query executes a query against store. -func (p *cockroachDBAccess) Query(req *state.QueryRequest) (*state.QueryResponse, error) { +func (p *cockroachDBAccess) Query(ctx context.Context, req *state.QueryRequest) (*state.QueryResponse, error) { p.logger.Debug("Getting query value from CockroachDB") stateQuery := &Query{ @@ -361,7 +362,7 @@ func (p *cockroachDBAccess) Query(req *state.QueryRequest) (*state.QueryResponse p.logger.Debug("Query: " + stateQuery.query) - data, token, err := stateQuery.execute(p.logger, p.db) + data, token, err := stateQuery.execute(ctx, p.logger, p.db) if err != nil { return &state.QueryResponse{ Results: []state.QueryItem{}, diff --git a/state/cockroachdb/cockroachdb_access_test.go b/state/cockroachdb/cockroachdb_access_test.go index e59fb8c86..c6eb83107 100644 --- a/state/cockroachdb/cockroachdb_access_test.go +++ b/state/cockroachdb/cockroachdb_access_test.go @@ -14,6 +14,7 @@ limitations under the License. package cockroachdb import ( + "context" "database/sql" "testing" @@ -109,7 +110,7 @@ func TestMultiWithNoRequests(t *testing.T) { var operations []state.TransactionalStateOperation // Act - err := m.roachDba.ExecuteMulti(&state.TransactionalStateRequest{ + err := m.roachDba.ExecuteMulti(context.TODO(), &state.TransactionalStateRequest{ Operations: operations, }) @@ -133,7 +134,7 @@ func TestInvalidMultiInvalidAction(t *testing.T) { }) // Act - err := m.roachDba.ExecuteMulti(&state.TransactionalStateRequest{ + err := m.roachDba.ExecuteMulti(context.TODO(), &state.TransactionalStateRequest{ Operations: operations, }) @@ -158,7 +159,7 @@ func TestValidSetRequest(t *testing.T) { }) // Act - err := m.roachDba.ExecuteMulti(&state.TransactionalStateRequest{ + err := m.roachDba.ExecuteMulti(context.TODO(), &state.TransactionalStateRequest{ Operations: operations, }) @@ -182,7 +183,7 @@ func TestInvalidMultiSetRequest(t *testing.T) { }) // Act - err := m.roachDba.ExecuteMulti(&state.TransactionalStateRequest{ + err := m.roachDba.ExecuteMulti(context.TODO(), &state.TransactionalStateRequest{ Operations: operations, }) @@ -206,7 +207,7 @@ func TestInvalidMultiSetRequestNoKey(t *testing.T) { }) // Act - err := m.roachDba.ExecuteMulti(&state.TransactionalStateRequest{ + err := m.roachDba.ExecuteMulti(context.TODO(), &state.TransactionalStateRequest{ Operations: operations, }) @@ -231,7 +232,7 @@ func TestValidMultiDeleteRequest(t *testing.T) { }) // Act - err := m.roachDba.ExecuteMulti(&state.TransactionalStateRequest{ + err := m.roachDba.ExecuteMulti(context.TODO(), &state.TransactionalStateRequest{ Operations: operations, }) @@ -255,7 +256,7 @@ func TestInvalidMultiDeleteRequest(t *testing.T) { }) // Act - err := m.roachDba.ExecuteMulti(&state.TransactionalStateRequest{ + err := m.roachDba.ExecuteMulti(context.TODO(), &state.TransactionalStateRequest{ Operations: operations, }) @@ -279,7 +280,7 @@ func TestInvalidMultiDeleteRequestNoKey(t *testing.T) { }) // Act - err := m.roachDba.ExecuteMulti(&state.TransactionalStateRequest{ + err := m.roachDba.ExecuteMulti(context.TODO(), &state.TransactionalStateRequest{ Operations: operations, }) @@ -311,7 +312,7 @@ func TestMultiOperationOrder(t *testing.T) { ) // Act - err := m.roachDba.ExecuteMulti(&state.TransactionalStateRequest{ + err := m.roachDba.ExecuteMulti(context.TODO(), &state.TransactionalStateRequest{ Operations: operations, }) @@ -334,7 +335,7 @@ func TestInvalidBulkSetNoKey(t *testing.T) { }) // Act - err := m.roachDba.BulkSet(sets) + err := m.roachDba.BulkSet(context.TODO(), sets) // Assert assert.NotNil(t, err) @@ -356,7 +357,7 @@ func TestInvalidBulkSetEmptyValue(t *testing.T) { }) // Act - err := m.roachDba.BulkSet(sets) + err := m.roachDba.BulkSet(context.TODO(), sets) // Assert assert.NotNil(t, err) @@ -379,7 +380,7 @@ func TestValidBulkSet(t *testing.T) { }) // Act - err := m.roachDba.BulkSet(sets) + err := m.roachDba.BulkSet(context.TODO(), sets) // Assert assert.Nil(t, err) @@ -400,7 +401,7 @@ func TestInvalidBulkDeleteNoKey(t *testing.T) { }) // Act - err := m.roachDba.BulkDelete(deletes) + err := m.roachDba.BulkDelete(context.TODO(), deletes) // Assert assert.NotNil(t, err) @@ -422,7 +423,7 @@ func TestValidBulkDelete(t *testing.T) { }) // Act - err := m.roachDba.BulkDelete(deletes) + err := m.roachDba.BulkDelete(context.TODO(), deletes) // Assert assert.Nil(t, err) diff --git a/state/cockroachdb/cockroachdb_integration_test.go b/state/cockroachdb/cockroachdb_integration_test.go index cbe82a384..dfd2df82a 100644 --- a/state/cockroachdb/cockroachdb_integration_test.go +++ b/state/cockroachdb/cockroachdb_integration_test.go @@ -14,6 +14,7 @@ limitations under the License. package cockroachdb import ( + "context" "database/sql" "encoding/json" "fmt" @@ -211,7 +212,7 @@ func deleteItemThatDoesNotExist(t *testing.T, pgs *CockroachDB) { Consistency: "", }, } - err := pgs.Delete(deleteReq) + err := pgs.Delete(context.TODO(), deleteReq) assert.Nil(t, err) } @@ -239,7 +240,7 @@ func multiWithSetOnly(t *testing.T, pgs *CockroachDB) { }) } - err := pgs.Multi(&state.TransactionalStateRequest{ + err := pgs.Multi(context.TODO(), &state.TransactionalStateRequest{ Operations: operations, Metadata: nil, }) @@ -280,7 +281,7 @@ func multiWithDeleteOnly(t *testing.T, pgs *CockroachDB) { }) } - err := pgs.Multi(&state.TransactionalStateRequest{ + err := pgs.Multi(context.TODO(), &state.TransactionalStateRequest{ Operations: operations, Metadata: nil, }) @@ -341,7 +342,7 @@ func multiWithDeleteAndSet(t *testing.T, pgs *CockroachDB) { }) } - err := pgs.Multi(&state.TransactionalStateRequest{ + err := pgs.Multi(context.TODO(), &state.TransactionalStateRequest{ Operations: operations, Metadata: nil, }) @@ -376,7 +377,7 @@ func deleteWithInvalidEtagFails(t *testing.T, pgs *CockroachDB) { Consistency: "", }, } - err := pgs.Delete(deleteReq) + err := pgs.Delete(context.TODO(), deleteReq) assert.NotNil(t, err) } @@ -392,7 +393,7 @@ func deleteWithNoKeyFails(t *testing.T, pgs *CockroachDB) { Consistency: "", }, } - err := pgs.Delete(deleteReq) + err := pgs.Delete(context.TODO(), deleteReq) assert.NotNil(t, err) } @@ -415,7 +416,7 @@ func newItemWithEtagFails(t *testing.T, pgs *CockroachDB) { ContentType: nil, } - err := pgs.Set(setReq) + err := pgs.Set(context.TODO(), setReq) assert.NotNil(t, err) } @@ -449,7 +450,7 @@ func updateWithOldEtagFails(t *testing.T, pgs *CockroachDB) { }, ContentType: nil, } - err := pgs.Set(setReq) + err := pgs.Set(context.TODO(), setReq) assert.NotNil(t, err) } @@ -501,7 +502,7 @@ func getItemWithNoKey(t *testing.T, pgs *CockroachDB) { }, } - response, getErr := pgs.Get(getReq) + response, getErr := pgs.Get(context.TODO(), getReq) assert.NotNil(t, getErr) assert.Nil(t, response) } @@ -544,7 +545,7 @@ func setItemWithNoKey(t *testing.T, pgs *CockroachDB) { ContentType: nil, } - err := pgs.Set(setReq) + err := pgs.Set(context.TODO(), setReq) assert.NotNil(t, err) } @@ -563,7 +564,7 @@ func testBulkSetAndBulkDelete(t *testing.T, pgs *CockroachDB) { }, } - err := pgs.BulkSet(setReq) + err := pgs.BulkSet(context.TODO(), setReq) assert.Nil(t, err) assert.True(t, storeItemExists(t, setReq[0].Key)) assert.True(t, storeItemExists(t, setReq[1].Key)) @@ -577,7 +578,7 @@ func testBulkSetAndBulkDelete(t *testing.T, pgs *CockroachDB) { }, } - err = pgs.BulkDelete(deleteReq) + err = pgs.BulkDelete(context.TODO(), deleteReq) assert.Nil(t, err) assert.False(t, storeItemExists(t, setReq[0].Key)) assert.False(t, storeItemExists(t, setReq[1].Key)) @@ -644,7 +645,7 @@ func setItem(t *testing.T, pgs *CockroachDB, key string, value interface{}, etag ContentType: nil, } - err := pgs.Set(setReq) + err := pgs.Set(context.TODO(), setReq) assert.Nil(t, err) itemExists := storeItemExists(t, key) assert.True(t, itemExists) @@ -661,7 +662,7 @@ func getItem(t *testing.T, pgs *CockroachDB, key string) (*state.GetResponse, *f Metadata: map[string]string{}, } - response, getErr := pgs.Get(getReq) + response, getErr := pgs.Get(context.TODO(), getReq) assert.Nil(t, getErr) assert.NotNil(t, response) outputObject := &fakeItem{ @@ -685,7 +686,7 @@ func deleteItem(t *testing.T, pgs *CockroachDB, key string, etag *string) { Metadata: map[string]string{}, } - deleteErr := pgs.Delete(deleteReq) + deleteErr := pgs.Delete(context.TODO(), deleteReq) assert.Nil(t, deleteErr) assert.False(t, storeItemExists(t, key)) } diff --git a/state/cockroachdb/cockroachdb_query.go b/state/cockroachdb/cockroachdb_query.go index 5fffb006a..78593f9a2 100644 --- a/state/cockroachdb/cockroachdb_query.go +++ b/state/cockroachdb/cockroachdb_query.go @@ -14,6 +14,7 @@ limitations under the License. package cockroachdb import ( + "context" "database/sql" "fmt" "strconv" @@ -133,8 +134,8 @@ func (q *Query) Finalize(filters string, storeQuery *query.Query) error { return nil } -func (q *Query) execute(logger logger.Logger, db *sql.DB) ([]state.QueryItem, string, error) { - rows, err := db.Query(q.query, q.params...) +func (q *Query) execute(ctx context.Context, logger logger.Logger, db *sql.DB) ([]state.QueryItem, string, error) { + rows, err := db.QueryContext(ctx, q.query, q.params...) if err != nil { return nil, "", fmt.Errorf("query executes '%s' failed: %w", q.query, err) } diff --git a/state/cockroachdb/cockroachdb_test.go b/state/cockroachdb/cockroachdb_test.go index ddf877d3d..2b3664ecd 100644 --- a/state/cockroachdb/cockroachdb_test.go +++ b/state/cockroachdb/cockroachdb_test.go @@ -14,6 +14,7 @@ limitations under the License. package cockroachdb import ( + "context" "testing" "github.com/stretchr/testify/assert" @@ -42,37 +43,37 @@ func (m *fakeDBaccess) Init(metadata state.Metadata) error { return nil } -func (m *fakeDBaccess) Set(req *state.SetRequest) error { +func (m *fakeDBaccess) Set(ctx context.Context, req *state.SetRequest) error { m.setExecuted = true return nil } -func (m *fakeDBaccess) Get(req *state.GetRequest) (*state.GetResponse, error) { +func (m *fakeDBaccess) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) { m.getExecuted = true return nil, nil } -func (m *fakeDBaccess) Delete(req *state.DeleteRequest) error { +func (m *fakeDBaccess) Delete(ctx context.Context, req *state.DeleteRequest) error { m.deleteExecuted = true return nil } -func (m *fakeDBaccess) BulkSet(req []state.SetRequest) error { +func (m *fakeDBaccess) BulkSet(ctx context.Context, req []state.SetRequest) error { return nil } -func (m *fakeDBaccess) BulkDelete(req []state.DeleteRequest) error { +func (m *fakeDBaccess) BulkDelete(ctx context.Context, req []state.DeleteRequest) error { return nil } -func (m *fakeDBaccess) ExecuteMulti(req *state.TransactionalStateRequest) error { +func (m *fakeDBaccess) ExecuteMulti(ctx context.Context, req *state.TransactionalStateRequest) error { return nil } -func (m *fakeDBaccess) Query(req *state.QueryRequest) (*state.QueryResponse, error) { +func (m *fakeDBaccess) Query(ctx context.Context, req *state.QueryRequest) (*state.QueryResponse, error) { return nil, nil } diff --git a/state/cockroachdb/dbaccess.go b/state/cockroachdb/dbaccess.go index d89af8a54..112b64bc0 100644 --- a/state/cockroachdb/dbaccess.go +++ b/state/cockroachdb/dbaccess.go @@ -13,18 +13,22 @@ limitations under the License. package cockroachdb -import "github.com/dapr/components-contrib/state" +import ( + "context" + + "github.com/dapr/components-contrib/state" +) // dbAccess is a private interface which enables unit testing of CockroachDB. type dbAccess interface { Init(metadata state.Metadata) error - Set(req *state.SetRequest) error - BulkSet(req []state.SetRequest) error - Get(req *state.GetRequest) (*state.GetResponse, error) - Delete(req *state.DeleteRequest) error - BulkDelete(req []state.DeleteRequest) error - ExecuteMulti(req *state.TransactionalStateRequest) error - Query(req *state.QueryRequest) (*state.QueryResponse, error) + Set(ctx context.Context, req *state.SetRequest) error + BulkSet(ctx context.Context, req []state.SetRequest) error + Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) + Delete(ctx context.Context, req *state.DeleteRequest) error + BulkDelete(ctx context.Context, req []state.DeleteRequest) error + ExecuteMulti(ctx context.Context, req *state.TransactionalStateRequest) error + Query(ctx context.Context, req *state.QueryRequest) (*state.QueryResponse, error) Ping() error Close() error } diff --git a/state/couchbase/couchbase.go b/state/couchbase/couchbase.go index d3e1a14a3..538dc6acc 100644 --- a/state/couchbase/couchbase.go +++ b/state/couchbase/couchbase.go @@ -14,6 +14,7 @@ limitations under the License. package couchbase import ( + "context" "errors" "fmt" "strconv" @@ -144,7 +145,7 @@ func (cbs *Couchbase) Features() []state.Feature { } // Set stores value for a key to couchbase. It honors ETag (for concurrency) and consistency settings. -func (cbs *Couchbase) Set(req *state.SetRequest) error { +func (cbs *Couchbase) Set(ctx context.Context, req *state.SetRequest) error { err := state.CheckRequestOptions(req.Options) if err != nil { return err @@ -188,7 +189,7 @@ func (cbs *Couchbase) Set(req *state.SetRequest) error { } // Get retrieves state from couchbase with a key. -func (cbs *Couchbase) Get(req *state.GetRequest) (*state.GetResponse, error) { +func (cbs *Couchbase) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) { var data interface{} cas, err := cbs.bucket.Get(req.Key, &data) if err != nil { @@ -206,7 +207,7 @@ func (cbs *Couchbase) Get(req *state.GetRequest) (*state.GetResponse, error) { } // Delete performs a delete operation. -func (cbs *Couchbase) Delete(req *state.DeleteRequest) error { +func (cbs *Couchbase) Delete(ctx context.Context, req *state.DeleteRequest) error { err := state.CheckRequestOptions(req.Options) if err != nil { return err diff --git a/state/gcp/firestore/firestore.go b/state/gcp/firestore/firestore.go index c88ee047d..9a37d7ee5 100644 --- a/state/gcp/firestore/firestore.go +++ b/state/gcp/firestore/firestore.go @@ -93,12 +93,12 @@ func (f *Firestore) Features() []state.Feature { } // Get retrieves state from Firestore with a key (Always strong consistency). -func (f *Firestore) Get(req *state.GetRequest) (*state.GetResponse, error) { +func (f *Firestore) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) { key := req.Key entityKey := datastore.NameKey(f.entityKind, key, nil) var entity StateEntity - err := f.client.Get(context.Background(), entityKey, &entity) + err := f.client.Get(ctx, entityKey, &entity) if err != nil && !errors.Is(err, datastore.ErrNoSuchEntity) { return nil, err @@ -111,7 +111,7 @@ func (f *Firestore) Get(req *state.GetRequest) (*state.GetResponse, error) { }, nil } -func (f *Firestore) setValue(req *state.SetRequest) error { +func (f *Firestore) setValue(ctx context.Context, req *state.SetRequest) error { err := state.CheckRequestOptions(req.Options) if err != nil { return err @@ -128,7 +128,6 @@ func (f *Firestore) setValue(req *state.SetRequest) error { entity := &StateEntity{ Value: v, } - ctx := context.Background() key := datastore.NameKey(f.entityKind, req.Key, nil) _, err = f.client.Put(ctx, key, entity) @@ -141,12 +140,11 @@ func (f *Firestore) setValue(req *state.SetRequest) error { } // Set saves state into Firestore with retry. -func (f *Firestore) Set(req *state.SetRequest) error { - return state.SetWithOptions(f.setValue, req) +func (f *Firestore) Set(ctx context.Context, req *state.SetRequest) error { + return state.SetWithOptions(ctx, f.setValue, req) } -func (f *Firestore) deleteValue(req *state.DeleteRequest) error { - ctx := context.Background() +func (f *Firestore) deleteValue(ctx context.Context, req *state.DeleteRequest) error { key := datastore.NameKey(f.entityKind, req.Key, nil) err := f.client.Delete(ctx, key) @@ -158,8 +156,8 @@ func (f *Firestore) deleteValue(req *state.DeleteRequest) error { } // Delete performs a delete operation. -func (f *Firestore) Delete(req *state.DeleteRequest) error { - return state.DeleteWithOptions(f.deleteValue, req) +func (f *Firestore) Delete(ctx context.Context, req *state.DeleteRequest) error { + return state.DeleteWithOptions(ctx, f.deleteValue, req) } func getFirestoreMetadata(metadata state.Metadata) (*firestoreMetadata, error) { diff --git a/state/hashicorp/consul/consul.go b/state/hashicorp/consul/consul.go index b16e78b2f..5387623d8 100644 --- a/state/hashicorp/consul/consul.go +++ b/state/hashicorp/consul/consul.go @@ -14,6 +14,7 @@ limitations under the License. package consul import ( + "context" "encoding/json" "fmt" @@ -102,11 +103,12 @@ func metadataToConfig(connInfo map[string]string) (*consulConfig, error) { } // Get retrieves a Consul KV item. -func (c *Consul) Get(req *state.GetRequest) (*state.GetResponse, error) { +func (c *Consul) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) { queryOpts := &api.QueryOptions{} if req.Options.Consistency == state.Strong { queryOpts.RequireConsistent = true } + queryOpts = queryOpts.WithContext(ctx) resp, queryMeta, err := c.client.KV().Get(fmt.Sprintf("%s/%s", c.keyPrefixPath, req.Key), queryOpts) if err != nil { @@ -124,7 +126,7 @@ func (c *Consul) Get(req *state.GetRequest) (*state.GetResponse, error) { } // Set saves a Consul KV item. -func (c *Consul) Set(req *state.SetRequest) error { +func (c *Consul) Set(ctx context.Context, req *state.SetRequest) error { var reqValByte []byte b, ok := req.Value.([]byte) if ok { @@ -135,10 +137,12 @@ func (c *Consul) Set(req *state.SetRequest) error { keyWithPath := fmt.Sprintf("%s/%s", c.keyPrefixPath, req.Key) + writeOptions := new(api.WriteOptions) + writeOptions = writeOptions.WithContext(ctx) _, err := c.client.KV().Put(&api.KVPair{ Key: keyWithPath, Value: reqValByte, - }, nil) + }, writeOptions) if err != nil { return fmt.Errorf("couldn't set key %s: %s", keyWithPath, err) } @@ -147,9 +151,11 @@ func (c *Consul) Set(req *state.SetRequest) error { } // Delete performes a Consul KV delete operation. -func (c *Consul) Delete(req *state.DeleteRequest) error { +func (c *Consul) Delete(ctx context.Context, req *state.DeleteRequest) error { keyWithPath := fmt.Sprintf("%s/%s", c.keyPrefixPath, req.Key) - _, err := c.client.KV().Delete(keyWithPath, nil) + writeOptions := new(api.WriteOptions) + writeOptions = writeOptions.WithContext(ctx) + _, err := c.client.KV().Delete(keyWithPath, writeOptions) if err != nil { return fmt.Errorf("couldn't delete key %s: %s", keyWithPath, err) } diff --git a/state/hazelcast/hazelcast.go b/state/hazelcast/hazelcast.go index 1c8043726..47a30134b 100644 --- a/state/hazelcast/hazelcast.go +++ b/state/hazelcast/hazelcast.go @@ -14,6 +14,7 @@ limitations under the License. package hazelcast import ( + "context" "errors" "fmt" "strings" @@ -91,7 +92,7 @@ func (store *Hazelcast) Features() []state.Feature { } // Set stores value for a key to Hazelcast. -func (store *Hazelcast) Set(req *state.SetRequest) error { +func (store *Hazelcast) Set(ctx context.Context, req *state.SetRequest) error { err := state.CheckRequestOptions(req) if err != nil { return err @@ -117,7 +118,7 @@ func (store *Hazelcast) Set(req *state.SetRequest) error { } // Get retrieves state from Hazelcast with a key. -func (store *Hazelcast) Get(req *state.GetRequest) (*state.GetResponse, error) { +func (store *Hazelcast) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) { resp, err := store.hzMap.Get(req.Key) if err != nil { return nil, fmt.Errorf("hazelcast error: failed to get value for %s: %s", req.Key, err) @@ -138,7 +139,7 @@ func (store *Hazelcast) Get(req *state.GetRequest) (*state.GetResponse, error) { } // Delete performs a delete operation. -func (store *Hazelcast) Delete(req *state.DeleteRequest) error { +func (store *Hazelcast) Delete(ctx context.Context, req *state.DeleteRequest) error { err := state.CheckRequestOptions(req.Options) if err != nil { return err diff --git a/state/in-memory/in_memory.go b/state/in-memory/in_memory.go index 96e86c1a1..21907fc1b 100644 --- a/state/in-memory/in_memory.go +++ b/state/in-memory/in_memory.go @@ -73,7 +73,7 @@ func (store *inMemoryStore) Features() []state.Feature { return []state.Feature{state.FeatureETag, state.FeatureTransactional} } -func (store *inMemoryStore) Delete(req *state.DeleteRequest) error { +func (store *inMemoryStore) Delete(ctx context.Context, req *state.DeleteRequest) error { // step1: validate parameters if err := store.doDeleteValidateParameters(req); err != nil { return err @@ -90,7 +90,7 @@ func (store *inMemoryStore) Delete(req *state.DeleteRequest) error { // step3: do really delete // this operation won't fail - store.doDelete(req.Key) + store.doDelete(ctx, req.Key) return nil } @@ -117,11 +117,11 @@ func (store *inMemoryStore) doValidateEtag(key string, etag *string, concurrency return nil } -func (store *inMemoryStore) doDelete(key string) { +func (store *inMemoryStore) doDelete(ctx context.Context, key string) { delete(store.items, key) } -func (store *inMemoryStore) BulkDelete(req []state.DeleteRequest) error { +func (store *inMemoryStore) BulkDelete(ctx context.Context, req []state.DeleteRequest) error { if len(req) == 0 { return nil } @@ -148,15 +148,15 @@ func (store *inMemoryStore) BulkDelete(req []state.DeleteRequest) error { // step3: do really delete for _, dr := range req { - store.doDelete(dr.Key) + store.doDelete(ctx, dr.Key) } return nil } -func (store *inMemoryStore) Get(req *state.GetRequest) (*state.GetResponse, error) { - item := store.doGetWithReadLock(req.Key) +func (store *inMemoryStore) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) { + item := store.doGetWithReadLock(ctx, req.Key) if item != nil && isExpired(item.expire) { - item = store.doGetWithWriteLock(req.Key) + item = store.doGetWithWriteLock(ctx, req.Key) } if item == nil { @@ -165,14 +165,14 @@ func (store *inMemoryStore) Get(req *state.GetRequest) (*state.GetResponse, erro return &state.GetResponse{Data: unmarshal(item.data), ETag: item.etag}, nil } -func (store *inMemoryStore) doGetWithReadLock(key string) *inMemStateStoreItem { +func (store *inMemoryStore) doGetWithReadLock(ctx context.Context, key string) *inMemStateStoreItem { store.lock.RLock() defer store.lock.RUnlock() return store.items[key] } -func (store *inMemoryStore) doGetWithWriteLock(key string) *inMemStateStoreItem { +func (store *inMemoryStore) doGetWithWriteLock(ctx context.Context, key string) *inMemStateStoreItem { store.lock.Lock() defer store.lock.Unlock() // get item and check expired again to avoid if item changed between we got this write-lock @@ -181,7 +181,7 @@ func (store *inMemoryStore) doGetWithWriteLock(key string) *inMemStateStoreItem return nil } if isExpired(item.expire) { - store.doDelete(key) + store.doDelete(ctx, key) return nil } return item @@ -194,11 +194,11 @@ func isExpired(expire int64) bool { return time.Now().UnixMilli() > expire } -func (store *inMemoryStore) BulkGet(req []state.GetRequest) (bool, []state.BulkGetResponse, error) { +func (store *inMemoryStore) BulkGet(ctx context.Context, req []state.GetRequest) (bool, []state.BulkGetResponse, error) { return false, nil, nil } -func (store *inMemoryStore) Set(req *state.SetRequest) error { +func (store *inMemoryStore) Set(ctx context.Context, req *state.SetRequest) error { // step1: validate parameters ttlInSeconds, err := store.doSetValidateParameters(req) if err != nil { @@ -217,7 +217,7 @@ func (store *inMemoryStore) Set(req *state.SetRequest) error { // step3: do really set // this operation won't fail - store.doSet(req.Key, b, req.ETag, ttlInSeconds) + store.doSet(ctx, req.Key, b, req.ETag, ttlInSeconds) return nil } @@ -253,7 +253,7 @@ func doParseTTLInSeconds(metadata map[string]string) (int, error) { return i, nil } -func (store *inMemoryStore) doSet(key string, data []byte, etag *string, ttlInSeconds int) { +func (store *inMemoryStore) doSet(ctx context.Context, key string, data []byte, etag *string, ttlInSeconds int) { store.items[key] = &inMemStateStoreItem{ data: data, etag: etag, @@ -268,7 +268,7 @@ type innerSetRequest struct { data []byte } -func (store *inMemoryStore) BulkSet(req []state.SetRequest) error { +func (store *inMemoryStore) BulkSet(ctx context.Context, req []state.SetRequest) error { if len(req) == 0 { return nil } @@ -305,12 +305,12 @@ func (store *inMemoryStore) BulkSet(req []state.SetRequest) error { // step3: do really set // these operations won't fail for _, innerSetRequest := range innerSetRequestList { - store.doSet(innerSetRequest.req.Key, innerSetRequest.data, innerSetRequest.req.ETag, innerSetRequest.ttlInSeconds) + store.doSet(ctx, innerSetRequest.req.Key, innerSetRequest.data, innerSetRequest.req.ETag, innerSetRequest.ttlInSeconds) } return nil } -func (store *inMemoryStore) Multi(request *state.TransactionalStateRequest) error { +func (store *inMemoryStore) Multi(ctx context.Context, request *state.TransactionalStateRequest) error { if len(request.Operations) == 0 { return nil } @@ -366,10 +366,10 @@ func (store *inMemoryStore) Multi(request *state.TransactionalStateRequest) erro for _, o := range request.Operations { if o.Operation == state.Upsert { s := o.Request.(innerSetRequest) - store.doSet(s.req.Key, s.data, s.req.ETag, s.ttlInSeconds) + store.doSet(ctx, s.req.Key, s.data, s.req.ETag, s.ttlInSeconds) } else if o.Operation == state.Delete { d := o.Request.(state.DeleteRequest) - store.doDelete(d.Key) + store.doDelete(ctx, d.Key) } } return nil @@ -406,7 +406,7 @@ func (store *inMemoryStore) doCleanExpiredItems() { for key, item := range store.items { if isExpired(item.expire) { - store.doDelete(key) + store.doDelete(context.Background(), key) } } } diff --git a/state/in-memory/in_memory_test.go b/state/in-memory/in_memory_test.go index 8a29c31ed..47763f837 100644 --- a/state/in-memory/in_memory_test.go +++ b/state/in-memory/in_memory_test.go @@ -14,6 +14,7 @@ limitations under the License. package inmemory import ( + "context" "testing" "time" @@ -43,13 +44,13 @@ func TestReadAndWrite(t *testing.T) { Value: valueA, ETag: ptr.String("the etag"), } - err := store.Set(setReq) + err := store.Set(context.TODO(), setReq) assert.Nil(t, err) // get after set getReq := &state.GetRequest{ Key: keyA, } - resp, err := store.Get(getReq) + resp, err := store.Get(context.TODO(), getReq) assert.Nil(t, err) assert.NotNil(t, resp) assert.Equal(t, valueA, string(resp.Data)) @@ -62,7 +63,7 @@ func TestReadAndWrite(t *testing.T) { Value: valueA, Metadata: map[string]string{"ttlInSeconds": "1"}, } - err := store.Set(setReq) + err := store.Set(context.TODO(), setReq) assert.Nil(t, err) // simulate expiration time.Sleep(2 * time.Second) @@ -70,7 +71,7 @@ func TestReadAndWrite(t *testing.T) { getReq := &state.GetRequest{ Key: keyA, } - resp, err := store.Get(getReq) + resp, err := store.Get(context.TODO(), getReq) assert.Nil(t, err) assert.NotNil(t, resp) assert.Nil(t, resp.Data) @@ -84,20 +85,20 @@ func TestReadAndWrite(t *testing.T) { Value: "1234", ETag: ptr.String("the etag"), } - err := store.Set(setReq) + err := store.Set(context.TODO(), setReq) assert.Nil(t, err) // get getReq := &state.GetRequest{ Key: "theSecondKey", } - resp, err := store.Get(getReq) + resp, err := store.Get(context.TODO(), getReq) assert.Nil(t, err) assert.NotNil(t, resp) assert.Equal(t, "1234", string(resp.Data)) }) t.Run("BulkSet two keys", func(t *testing.T) { - err := store.BulkSet([]state.SetRequest{{ + err := store.BulkSet(context.TODO(), []state.SetRequest{{ Key: "theFirstKey", Value: "666", }, { @@ -109,7 +110,7 @@ func TestReadAndWrite(t *testing.T) { }) t.Run("BulkGet fails when not supported", func(t *testing.T) { - supportBulk, _, err := store.BulkGet([]state.GetRequest{{ + supportBulk, _, err := store.BulkGet(context.TODO(), []state.GetRequest{{ Key: "theFirstKey", }, { Key: "theSecondKey", @@ -123,7 +124,7 @@ func TestReadAndWrite(t *testing.T) { req := &state.DeleteRequest{ Key: "theFirstKey", } - err := store.Delete(req) + err := store.Delete(context.TODO(), req) assert.Nil(t, err) }) } diff --git a/state/jetstream/jetstream.go b/state/jetstream/jetstream.go index 930d56f5f..998ef1333 100644 --- a/state/jetstream/jetstream.go +++ b/state/jetstream/jetstream.go @@ -14,6 +14,7 @@ limitations under the License. package jetstream import ( + "context" "fmt" "strings" @@ -97,7 +98,7 @@ func (js *StateStore) Features() []state.Feature { } // Get retrieves state with a key. -func (js *StateStore) Get(req *state.GetRequest) (*state.GetResponse, error) { +func (js *StateStore) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) { entry, err := js.bucket.Get(escape(req.Key)) if err != nil { return nil, err @@ -109,14 +110,14 @@ func (js *StateStore) Get(req *state.GetRequest) (*state.GetResponse, error) { } // Set stores value for a key. -func (js *StateStore) Set(req *state.SetRequest) error { +func (js *StateStore) Set(ctx context.Context, req *state.SetRequest) error { bt, _ := utils.Marshal(req.Value, js.json.Marshal) _, err := js.bucket.Put(escape(req.Key), bt) return err } // Delete performs a delete operation. -func (js *StateStore) Delete(req *state.DeleteRequest) error { +func (js *StateStore) Delete(ctx context.Context, req *state.DeleteRequest) error { return js.bucket.Delete(escape(req.Key)) } diff --git a/state/jetstream/jetstream_test.go b/state/jetstream/jetstream_test.go index 7d3610c91..ad15adf3c 100644 --- a/state/jetstream/jetstream_test.go +++ b/state/jetstream/jetstream_test.go @@ -14,6 +14,7 @@ limitations under the License. package jetstream import ( + "context" "encoding/json" "fmt" "reflect" @@ -112,7 +113,7 @@ func TestSetGetAndDelete(t *testing.T) { "dkey": "dvalue", } - err = store.Set(&state.SetRequest{ + err = store.Set(context.TODO(), &state.SetRequest{ Key: tkey, Value: tData, }) @@ -121,7 +122,7 @@ func TestSetGetAndDelete(t *testing.T) { return } - resp, err := store.Get(&state.GetRequest{ + resp, err := store.Get(context.TODO(), &state.GetRequest{ Key: tkey, }) if err != nil { @@ -134,7 +135,7 @@ func TestSetGetAndDelete(t *testing.T) { t.Fatal("Response data does not match written data\n") } - err = store.Delete(&state.DeleteRequest{ + err = store.Delete(context.TODO(), &state.DeleteRequest{ Key: tkey, }) if err != nil { @@ -142,7 +143,7 @@ func TestSetGetAndDelete(t *testing.T) { return } - _, err = store.Get(&state.GetRequest{ + _, err = store.Get(context.TODO(), &state.GetRequest{ Key: tkey, }) if err == nil { diff --git a/state/memcached/memcached.go b/state/memcached/memcached.go index 8875ffb66..921cb5ddc 100644 --- a/state/memcached/memcached.go +++ b/state/memcached/memcached.go @@ -14,6 +14,7 @@ limitations under the License. package memcached import ( + "context" "errors" "fmt" "strconv" @@ -139,7 +140,7 @@ func (m *Memcached) parseTTL(req *state.SetRequest) (*int32, error) { return nil, nil } -func (m *Memcached) setValue(req *state.SetRequest) error { +func (m *Memcached) setValue(ctx context.Context, req *state.SetRequest) error { var bt []byte ttl, err := m.parseTTL(req) if err != nil { @@ -159,7 +160,7 @@ func (m *Memcached) setValue(req *state.SetRequest) error { return nil } -func (m *Memcached) Delete(req *state.DeleteRequest) error { +func (m *Memcached) Delete(ctx context.Context, req *state.DeleteRequest) error { err := m.client.Delete(req.Key) if err != nil { if err == memcache.ErrCacheMiss { @@ -171,7 +172,7 @@ func (m *Memcached) Delete(req *state.DeleteRequest) error { return nil } -func (m *Memcached) Get(req *state.GetRequest) (*state.GetResponse, error) { +func (m *Memcached) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) { item, err := m.client.Get(req.Key) if err != nil { // Return nil for status 204 @@ -187,6 +188,6 @@ func (m *Memcached) Get(req *state.GetRequest) (*state.GetResponse, error) { }, nil } -func (m *Memcached) Set(req *state.SetRequest) error { - return state.SetWithOptions(m.setValue, req) +func (m *Memcached) Set(ctx context.Context, req *state.SetRequest) error { + return state.SetWithOptions(ctx, m.setValue, req) } diff --git a/state/mongodb/mongodb.go b/state/mongodb/mongodb.go index f4abc97e5..f87fa0dff 100644 --- a/state/mongodb/mongodb.go +++ b/state/mongodb/mongodb.go @@ -159,10 +159,7 @@ func (m *MongoDB) Features() []state.Feature { } // Set saves state into MongoDB. -func (m *MongoDB) Set(req *state.SetRequest) error { - ctx, cancel := context.WithTimeout(context.Background(), m.operationTimeout) - defer cancel() - +func (m *MongoDB) Set(ctx context.Context, req *state.SetRequest) error { err := m.setInternal(ctx, req) if err != nil { return err @@ -205,12 +202,9 @@ func (m *MongoDB) setInternal(ctx context.Context, req *state.SetRequest) error } // Get retrieves state from MongoDB with a key. -func (m *MongoDB) Get(req *state.GetRequest) (*state.GetResponse, error) { +func (m *MongoDB) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) { var result Item - ctx, cancel := context.WithTimeout(context.Background(), m.operationTimeout) - defer cancel() - filter := bson.M{id: req.Key} err := m.collection.FindOne(ctx, filter).Decode(&result) if err != nil { @@ -264,10 +258,7 @@ func (m *MongoDB) Get(req *state.GetRequest) (*state.GetResponse, error) { } // Delete performs a delete operation. -func (m *MongoDB) Delete(req *state.DeleteRequest) error { - ctx, cancel := context.WithTimeout(context.Background(), m.operationTimeout) - defer cancel() - +func (m *MongoDB) Delete(ctx context.Context, req *state.DeleteRequest) error { err := m.deleteInternal(ctx, req) if err != nil { return err @@ -294,18 +285,18 @@ func (m *MongoDB) deleteInternal(ctx context.Context, req *state.DeleteRequest) } // Multi performs a transactional operation. succeeds only if all operations succeed, and fails if one or more operations fail. -func (m *MongoDB) Multi(request *state.TransactionalStateRequest) error { +func (m *MongoDB) Multi(ctx context.Context, request *state.TransactionalStateRequest) error { sess, err := m.client.StartSession() txnOpts := options.Transaction().SetReadConcern(readconcern.Snapshot()). SetWriteConcern(writeconcern.New(writeconcern.WMajority())) - defer sess.EndSession(context.Background()) + defer sess.EndSession(ctx) if err != nil { return fmt.Errorf("error in starting the transaction: %s", err) } - sess.WithTransaction(context.Background(), func(sessCtx mongo.SessionContext) (interface{}, error) { + sess.WithTransaction(ctx, func(sessCtx mongo.SessionContext) (interface{}, error) { err = m.doTransaction(sessCtx, request.Operations) return nil, err @@ -336,10 +327,7 @@ func (m *MongoDB) doTransaction(sessCtx mongo.SessionContext, operations []state } // Query executes a query against store. -func (m *MongoDB) Query(req *state.QueryRequest) (*state.QueryResponse, error) { - ctx, cancel := context.WithTimeout(context.Background(), m.operationTimeout) - defer cancel() - +func (m *MongoDB) Query(ctx context.Context, req *state.QueryRequest) (*state.QueryResponse, error) { q := &Query{} qbuilder := query.NewQueryBuilder(q) if err := qbuilder.BuildQuery(&req.Query); err != nil { diff --git a/state/mysql/mysql.go b/state/mysql/mysql.go index ca3bbbc76..0278b4dff 100644 --- a/state/mysql/mysql.go +++ b/state/mysql/mysql.go @@ -14,6 +14,7 @@ limitations under the License. package mysql import ( + "context" "database/sql" "encoding/base64" "encoding/json" @@ -279,13 +280,13 @@ func tableExists(db *sql.DB, tableName string) (bool, error) { // Delete removes an entity from the store // Store Interface. -func (m *MySQL) Delete(req *state.DeleteRequest) error { - return state.DeleteWithOptions(m.deleteValue, req) +func (m *MySQL) Delete(ctx context.Context, req *state.DeleteRequest) error { + return state.DeleteWithOptions(ctx, m.deleteValue, req) } // deleteValue is an internal implementation of delete to enable passing the // logic to state.DeleteWithRetries as a func. -func (m *MySQL) deleteValue(req *state.DeleteRequest) error { +func (m *MySQL) deleteValue(ctx context.Context, req *state.DeleteRequest) error { m.logger.Debug("Deleting state value from MySql") if req.Key == "" { @@ -296,11 +297,11 @@ func (m *MySQL) deleteValue(req *state.DeleteRequest) error { var result sql.Result if req.ETag == nil || *req.ETag == "" { - result, err = m.db.Exec(fmt.Sprintf( + result, err = m.db.ExecContext(ctx, fmt.Sprintf( `DELETE FROM %s WHERE id = ?`, m.tableName), req.Key) } else { - result, err = m.db.Exec(fmt.Sprintf( + result, err = m.db.ExecContext(ctx, fmt.Sprintf( `DELETE FROM %s WHERE id = ? and eTag = ?`, m.tableName), req.Key, *req.ETag) } @@ -323,7 +324,7 @@ func (m *MySQL) deleteValue(req *state.DeleteRequest) error { // BulkDelete removes multiple entries from the store // Store Interface. -func (m *MySQL) BulkDelete(req []state.DeleteRequest) error { +func (m *MySQL) BulkDelete(ctx context.Context, req []state.DeleteRequest) error { m.logger.Debug("Executing BulkDelete request") tx, err := m.db.Begin() @@ -334,7 +335,7 @@ func (m *MySQL) BulkDelete(req []state.DeleteRequest) error { if len(req) > 0 { for _, d := range req { da := d // Fix for goSec G601: Implicit memory aliasing in for loop. - err = m.Delete(&da) + err = m.Delete(ctx, &da) if err != nil { tx.Rollback() @@ -350,7 +351,7 @@ func (m *MySQL) BulkDelete(req []state.DeleteRequest) error { // Get returns an entity from store // Store Interface. -func (m *MySQL) Get(req *state.GetRequest) (*state.GetResponse, error) { +func (m *MySQL) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) { m.logger.Debug("Getting state value from MySql") if req.Key == "" { @@ -368,7 +369,7 @@ func (m *MySQL) Get(req *state.GetRequest) (*state.GetResponse, error) { `SELECT value, eTag, isbinary FROM %s WHERE id = ?`, m.tableName, // m.tableName is sanitized ) - err := m.db.QueryRow(query, req.Key).Scan(&value, &eTag, &isBinary) + err := m.db.QueryRowContext(ctx, query, req.Key).Scan(&value, &eTag, &isBinary) if err != nil { // If no rows exist, return an empty response, otherwise return an error. if errors.Is(err, sql.ErrNoRows) { @@ -410,13 +411,13 @@ func (m *MySQL) Get(req *state.GetRequest) (*state.GetResponse, error) { // Set adds/updates an entity on store // Store Interface. -func (m *MySQL) Set(req *state.SetRequest) error { - return state.SetWithOptions(m.setValue, req) +func (m *MySQL) Set(ctx context.Context, req *state.SetRequest) error { + return state.SetWithOptions(ctx, m.setValue, req) } // setValue is an internal implementation of set to enable passing the logic // to state.SetWithRetries as a func. -func (m *MySQL) setValue(req *state.SetRequest) error { +func (m *MySQL) setValue(ctx context.Context, req *state.SetRequest) error { m.logger.Debug("Setting state value in MySql") err := state.CheckRequestOptions(req.Options) @@ -457,7 +458,7 @@ func (m *MySQL) setValue(req *state.SetRequest) error { `INSERT INTO %s (value, id, eTag, isbinary) VALUES (?, ?, ?, ?);`, m.tableName, // m.tableName is sanitized ) - result, err = m.db.Exec(query, enc, req.Key, eTag, isBinary) + result, err = m.db.ExecContext(ctx, query, enc, req.Key, eTag, isBinary) } else if req.ETag != nil && *req.ETag != "" { // When an eTag is provided do an update - not insert //nolint:gosec @@ -465,7 +466,7 @@ func (m *MySQL) setValue(req *state.SetRequest) error { `UPDATE %s SET value = ?, eTag = ?, isbinary = ? WHERE id = ? AND eTag = ?;`, m.tableName, // m.tableName is sanitized ) - result, err = m.db.Exec(query, enc, eTag, isBinary, req.Key, *req.ETag) + result, err = m.db.ExecContext(ctx, query, enc, eTag, isBinary, req.Key, *req.ETag) } else { // If this is a duplicate MySQL returns that two rows affected maxRows = 2 @@ -474,7 +475,7 @@ func (m *MySQL) setValue(req *state.SetRequest) error { `INSERT INTO %s (value, id, eTag, isbinary) VALUES (?, ?, ?, ?) on duplicate key update value=?, eTag=?, isbinary=?;`, m.tableName, // m.tableName is sanitized ) - result, err = m.db.Exec(query, enc, req.Key, eTag, isBinary, enc, eTag, isBinary) + result, err = m.db.ExecContext(ctx, query, enc, req.Key, eTag, isBinary, enc, eTag, isBinary) } if err != nil { @@ -508,7 +509,7 @@ func (m *MySQL) setValue(req *state.SetRequest) error { // BulkSet adds/updates multiple entities on store // Store Interface. -func (m *MySQL) BulkSet(req []state.SetRequest) error { +func (m *MySQL) BulkSet(ctx context.Context, req []state.SetRequest) error { m.logger.Debug("Executing BulkSet request") tx, err := m.db.Begin() @@ -518,7 +519,7 @@ func (m *MySQL) BulkSet(req []state.SetRequest) error { if len(req) > 0 { for i := range req { - err = m.Set(&req[i]) + err = m.Set(ctx, &req[i]) if err != nil { tx.Rollback() return err @@ -531,7 +532,7 @@ func (m *MySQL) BulkSet(req []state.SetRequest) error { // Multi handles multiple transactions. // TransactionalStore Interface. -func (m *MySQL) Multi(request *state.TransactionalStateRequest) error { +func (m *MySQL) Multi(ctx context.Context, request *state.TransactionalStateRequest) error { m.logger.Debug("Executing Multi request") tx, err := m.db.Begin() @@ -548,7 +549,7 @@ func (m *MySQL) Multi(request *state.TransactionalStateRequest) error { return err } - err = m.Set(&setReq) + err = m.Set(ctx, &setReq) if err != nil { _ = tx.Rollback() return err @@ -561,7 +562,7 @@ func (m *MySQL) Multi(request *state.TransactionalStateRequest) error { return err } - err = m.Delete(&delReq) + err = m.Delete(ctx, &delReq) if err != nil { _ = tx.Rollback() return err @@ -604,7 +605,7 @@ func (m *MySQL) getDeletes(req state.TransactionalStateOperation) (state.DeleteR } // BulkGet performs a bulks get operations. -func (m *MySQL) BulkGet(req []state.GetRequest) (bool, []state.BulkGetResponse, error) { +func (m *MySQL) BulkGet(ctx context.Context, req []state.GetRequest) (bool, []state.BulkGetResponse, error) { // by default, the store doesn't support bulk get // return false so daprd will fallback to call get() method one by one return false, nil, nil diff --git a/state/mysql/mysql_integration_test.go b/state/mysql/mysql_integration_test.go index db4382a9e..5f7034d7a 100644 --- a/state/mysql/mysql_integration_test.go +++ b/state/mysql/mysql_integration_test.go @@ -15,6 +15,7 @@ limitations under the License. package mysql import ( + "context" "crypto/tls" "crypto/x509" "database/sql" @@ -205,7 +206,7 @@ func TestMySQLIntegration(t *testing.T) { Key: "", } - response, getErr := mys.Get(getReq) + response, getErr := mys.Get(context.TODO(), getReq) assert.NotNil(t, getErr) assert.Nil(t, response) }) @@ -242,7 +243,7 @@ func TestMySQLIntegration(t *testing.T) { Key: "", } - err := mys.Set(setReq) + err := mys.Set(context.TODO(), setReq) assert.NotNil(t, err, "Error was not nil when setting item with no key.") }) @@ -302,7 +303,7 @@ func TestMySQLIntegration(t *testing.T) { Value: newValue, } - err := mys.Set(setReq) + err := mys.Set(context.TODO(), setReq) assert.NotNil(t, err, "Error was not thrown using old eTag") }) @@ -318,7 +319,7 @@ func TestMySQLIntegration(t *testing.T) { Value: value, } - err := mys.Set(setReq) + err := mys.Set(context.TODO(), setReq) assert.NotNil(t, err) }) @@ -338,7 +339,7 @@ func TestMySQLIntegration(t *testing.T) { ETag: &eTag, } - err := mys.Delete(deleteReq) + err := mys.Delete(context.TODO(), deleteReq) assert.NotNil(t, err) }) @@ -349,7 +350,7 @@ func TestMySQLIntegration(t *testing.T) { Key: "", } - err := mys.Delete(deleteReq) + err := mys.Delete(context.TODO(), deleteReq) assert.NotNil(t, err) }) @@ -361,7 +362,7 @@ func TestMySQLIntegration(t *testing.T) { Key: randomKey(), } - err := mys.Delete(deleteReq) + err := mys.Delete(context.TODO(), deleteReq) assert.Nil(t, err) }) @@ -378,7 +379,7 @@ func TestMySQLIntegration(t *testing.T) { }, } - err := mys.Set(setReq) + err := mys.Set(context.TODO(), setReq) assert.NoError(t, err) // Get the etag @@ -396,7 +397,7 @@ func TestMySQLIntegration(t *testing.T) { }, } - err = mys.Set(setReq) + err = mys.Set(context.TODO(), setReq) assert.ErrorContains(t, err, "Duplicate entry") // Insert with invalid etag should fail on existing keys @@ -409,7 +410,7 @@ func TestMySQLIntegration(t *testing.T) { }, } - err = mys.Set(setReq) + err = mys.Set(context.TODO(), setReq) assert.ErrorContains(t, err, "possible etag mismatch") // Insert with valid etag should succeed on existing keys @@ -422,7 +423,7 @@ func TestMySQLIntegration(t *testing.T) { }, } - err = mys.Set(setReq) + err = mys.Set(context.TODO(), setReq) assert.NoError(t, err) // Insert with an etag should fail on new keys @@ -435,7 +436,7 @@ func TestMySQLIntegration(t *testing.T) { }, } - err = mys.Set(setReq) + err = mys.Set(context.TODO(), setReq) assert.ErrorContains(t, err, "possible etag mismatch") }) @@ -474,7 +475,7 @@ func TestMySQLIntegration(t *testing.T) { }) } - err := mys.Multi(&state.TransactionalStateRequest{ + err := mys.Multi(context.TODO(), &state.TransactionalStateRequest{ Operations: operations, }) assert.Nil(t, err) @@ -510,7 +511,7 @@ func TestMySQLIntegration(t *testing.T) { }) } - err := mys.Multi(&state.TransactionalStateRequest{ + err := mys.Multi(context.TODO(), &state.TransactionalStateRequest{ Operations: operations, }) assert.Nil(t, err) @@ -537,7 +538,7 @@ func TestMySQLIntegration(t *testing.T) { }) } - err := mys.Multi(&state.TransactionalStateRequest{ + err := mys.Multi(context.TODO(), &state.TransactionalStateRequest{ Operations: operations, }) assert.Nil(t, err) @@ -562,7 +563,7 @@ func testBulkSetAndBulkDelete(t *testing.T, mys *MySQL) { }, } - err := mys.BulkSet(setReq) + err := mys.BulkSet(context.TODO(), setReq) assert.Nil(t, err) assert.True(t, storeItemExists(t, setReq[0].Key)) assert.True(t, storeItemExists(t, setReq[1].Key)) @@ -576,7 +577,7 @@ func testBulkSetAndBulkDelete(t *testing.T, mys *MySQL) { }, } - err = mys.BulkDelete(deleteReq) + err = mys.BulkDelete(context.TODO(), deleteReq) assert.Nil(t, err) assert.False(t, storeItemExists(t, setReq[0].Key)) assert.False(t, storeItemExists(t, setReq[1].Key)) @@ -596,7 +597,7 @@ func setItem(t *testing.T, mys *MySQL, key string, value interface{}, eTag *stri Value: value, } - err := mys.Set(setReq) + err := mys.Set(context.TODO(), setReq) assert.Nil(t, err, "Error setting an item") itemExists := storeItemExists(t, key) assert.True(t, itemExists, "Item does not exist after being set") @@ -608,7 +609,7 @@ func getItem(t *testing.T, mys *MySQL, key string) (*state.GetResponse, *fakeIte Options: state.GetStateOption{}, } - response, getErr := mys.Get(getReq) + response, getErr := mys.Get(context.TODO(), getReq) assert.Nil(t, getErr) assert.NotNil(t, response) outputObject := &fakeItem{} @@ -624,7 +625,7 @@ func deleteItem(t *testing.T, mys *MySQL, key string, eTag *string) { Options: state.DeleteStateOption{}, } - deleteErr := mys.Delete(deleteReq) + deleteErr := mys.Delete(context.TODO(), deleteReq) assert.Nil(t, deleteErr, "There was an error deleting a record") assert.False(t, storeItemExists(t, key), "Item still exists after delete") } diff --git a/state/mysql/mysql_test.go b/state/mysql/mysql_test.go index ed7acd4c9..67ee04516 100644 --- a/state/mysql/mysql_test.go +++ b/state/mysql/mysql_test.go @@ -15,6 +15,7 @@ limitations under the License. package mysql import ( + "context" "database/sql" "encoding/base64" "encoding/json" @@ -161,7 +162,7 @@ func TestExecuteMultiCannotBeginTransaction(t *testing.T) { m.mock1.ExpectBegin().WillReturnError(fmt.Errorf("beginError")) // Act - err := m.mySQL.Multi(nil) + err := m.mySQL.Multi(context.TODO(), nil) // Assert assert.NotNil(t, err, "no error returned") @@ -180,7 +181,7 @@ func TestMySQLBulkDeleteRollbackDeletes(t *testing.T) { deletes := []state.DeleteRequest{createDeleteRequest()} // Act - err := m.mySQL.BulkDelete(deletes) + err := m.mySQL.BulkDelete(context.TODO(), deletes) // Assert assert.NotNil(t, err, "no error returned") @@ -199,7 +200,7 @@ func TestMySQLBulkSetRollbackSets(t *testing.T) { sets := []state.SetRequest{createSetRequest()} // Act - err := m.mySQL.BulkSet(sets) + err := m.mySQL.BulkSet(context.TODO(), sets) // Assert assert.NotNil(t, err, "no error returned") @@ -232,7 +233,7 @@ func TestExecuteMultiCommitSetsAndDeletes(t *testing.T) { } // Act - err := m.mySQL.Multi(&request) + err := m.mySQL.Multi(context.TODO(), &request) // Assert assert.Nil(t, err, "error returned") @@ -248,7 +249,7 @@ func TestSetHandlesOptionsError(t *testing.T) { request.Options.Consistency = "Invalid" // Act - err := m.mySQL.setValue(&request) + err := m.mySQL.setValue(context.TODO(), &request) // Assert assert.NotNil(t, err) @@ -263,7 +264,7 @@ func TestSetHandlesNoKey(t *testing.T) { request.Key = "" // Act - err := m.mySQL.Set(&request) + err := m.mySQL.Set(context.TODO(), &request) // Assert assert.NotNil(t, err) @@ -283,7 +284,7 @@ func TestSetHandlesUpdate(t *testing.T) { request.ETag = &eTag // Act - err := m.mySQL.setValue(&request) + err := m.mySQL.setValue(context.TODO(), &request) // Assert assert.Nil(t, err) @@ -302,7 +303,7 @@ func TestSetHandlesErr(t *testing.T) { request.ETag = &eTag // Act - err := m.mySQL.setValue(&request) + err := m.mySQL.setValue(context.TODO(), &request) // Assert assert.NotNil(t, err) @@ -315,7 +316,7 @@ func TestSetHandlesErr(t *testing.T) { request := createSetRequest() // Act - err := m.mySQL.setValue(&request) + err := m.mySQL.setValue(context.TODO(), &request) // Assert assert.NotNil(t, err) @@ -327,7 +328,7 @@ func TestSetHandlesErr(t *testing.T) { request := createSetRequest() // Act - err := m.mySQL.setValue(&request) + err := m.mySQL.setValue(context.TODO(), &request) // Assert assert.Nil(t, err) @@ -338,7 +339,7 @@ func TestSetHandlesErr(t *testing.T) { request := createSetRequest() // Act - err := m.mySQL.setValue(&request) + err := m.mySQL.setValue(context.TODO(), &request) // Assert assert.NotNil(t, err) @@ -352,7 +353,7 @@ func TestSetHandlesErr(t *testing.T) { request.ETag = &eTag // Act - err := m.mySQL.setValue(&request) + err := m.mySQL.setValue(context.TODO(), &request) // Assert assert.NotNil(t, err) @@ -369,7 +370,7 @@ func TestMySQLDeleteHandlesNoKey(t *testing.T) { request.Key = "" // Act - err := m.mySQL.Delete(&request) + err := m.mySQL.Delete(context.TODO(), &request) // Asset assert.NotNil(t, err) @@ -388,7 +389,7 @@ func TestDeleteWithETag(t *testing.T) { request.ETag = &eTag // Act - err := m.mySQL.deleteValue(&request) + err := m.mySQL.deleteValue(context.TODO(), &request) // Assert assert.Nil(t, err) @@ -405,7 +406,7 @@ func TestDeleteWithErr(t *testing.T) { request := createDeleteRequest() // Act - err := m.mySQL.deleteValue(&request) + err := m.mySQL.deleteValue(context.TODO(), &request) // Assert assert.NotNil(t, err) @@ -420,7 +421,7 @@ func TestDeleteWithErr(t *testing.T) { request.ETag = &eTag // Act - err := m.mySQL.deleteValue(&request) + err := m.mySQL.deleteValue(context.TODO(), &request) // Assert assert.NotNil(t, err) @@ -441,7 +442,7 @@ func TestGetHandlesNoRows(t *testing.T) { } // Act - response, err := m.mySQL.Get(request) + response, err := m.mySQL.Get(context.TODO(), request) // Assert assert.Nil(t, err, "returned error") @@ -458,7 +459,7 @@ func TestGetHandlesNoKey(t *testing.T) { } // Act - response, err := m.mySQL.Get(request) + response, err := m.mySQL.Get(context.TODO(), request) // Assert assert.NotNil(t, err, "returned error") @@ -478,7 +479,7 @@ func TestGetHandlesGenericError(t *testing.T) { } // Act - response, err := m.mySQL.Get(request) + response, err := m.mySQL.Get(context.TODO(), request) // Assert assert.NotNil(t, err) @@ -499,7 +500,7 @@ func TestGetSucceeds(t *testing.T) { } // Act - response, err := m.mySQL.Get(request) + response, err := m.mySQL.Get(context.TODO(), request) // Assert assert.Nil(t, err) @@ -517,7 +518,7 @@ func TestGetSucceeds(t *testing.T) { } // Act - response, err := m.mySQL.Get(request) + response, err := m.mySQL.Get(context.TODO(), request) // Assert assert.Nil(t, err) @@ -711,7 +712,7 @@ func TestBulkGetReturnsNil(t *testing.T) { m, _ := mockDatabase(t) // Act - supported, response, err := m.mySQL.BulkGet(nil) + supported, response, err := m.mySQL.BulkGet(context.TODO(), nil) // Assert assert.Nil(t, err, `returned err`) @@ -730,7 +731,7 @@ func TestMultiWithNoRequestsDoesNothing(t *testing.T) { m.mock1.ExpectCommit() // Act - err := m.mySQL.Multi(&state.TransactionalStateRequest{ + err := m.mySQL.Multi(context.TODO(), &state.TransactionalStateRequest{ Operations: ops, }) @@ -750,7 +751,7 @@ func TestInvalidMultiAction(t *testing.T) { }) // Act - err := m.mySQL.Multi(&state.TransactionalStateRequest{ + err := m.mySQL.Multi(context.TODO(), &state.TransactionalStateRequest{ Operations: ops, }) @@ -789,7 +790,7 @@ func TestValidSetRequest(t *testing.T) { m.mock1.ExpectCommit() // Act - err := m.mySQL.Multi(&state.TransactionalStateRequest{ + err := m.mySQL.Multi(context.TODO(), &state.TransactionalStateRequest{ Operations: ops, }) @@ -810,7 +811,7 @@ func TestInvalidMultiSetRequest(t *testing.T) { }) // Act - err := m.mySQL.Multi(&state.TransactionalStateRequest{ + err := m.mySQL.Multi(context.TODO(), &state.TransactionalStateRequest{ Operations: ops, }) @@ -834,7 +835,7 @@ func TestInvalidMultiSetRequestNoKey(t *testing.T) { }) // Act - err := m.mySQL.Multi(&state.TransactionalStateRequest{ + err := m.mySQL.Multi(context.TODO(), &state.TransactionalStateRequest{ Operations: ops, }) @@ -858,7 +859,7 @@ func TestValidMultiDeleteRequest(t *testing.T) { }) // Act - err := m.mySQL.Multi(&state.TransactionalStateRequest{ + err := m.mySQL.Multi(context.TODO(), &state.TransactionalStateRequest{ Operations: ops, }) @@ -879,7 +880,7 @@ func TestInvalidMultiDeleteRequest(t *testing.T) { }) // Act - err := m.mySQL.Multi(&state.TransactionalStateRequest{ + err := m.mySQL.Multi(context.TODO(), &state.TransactionalStateRequest{ Operations: ops, }) @@ -902,7 +903,7 @@ func TestInvalidMultiDeleteRequestNoKey(t *testing.T) { }) // Act - err := m.mySQL.Multi(&state.TransactionalStateRequest{ + err := m.mySQL.Multi(context.TODO(), &state.TransactionalStateRequest{ Operations: ops, }) @@ -941,7 +942,7 @@ func TestMultiOperationOrder(t *testing.T) { m.mock1.ExpectCommit() // Act - err := m.mySQL.Multi(&state.TransactionalStateRequest{ + err := m.mySQL.Multi(context.TODO(), &state.TransactionalStateRequest{ Operations: ops, }) diff --git a/state/oci/objectstorage/objectstorage.go b/state/oci/objectstorage/objectstorage.go index d4ad1c16a..8351c821f 100644 --- a/state/oci/objectstorage/objectstorage.go +++ b/state/oci/objectstorage/objectstorage.go @@ -131,15 +131,15 @@ func (r *StateStore) Features() []state.Feature { return r.features } -func (r *StateStore) Delete(req *state.DeleteRequest) error { +func (r *StateStore) Delete(ctx context.Context, req *state.DeleteRequest) error { r.logger.Debugf("Delete entry from OCI Object Storage State Store with key ", req.Key) - err := r.deleteDocument(req) + err := r.deleteDocument(ctx, req) return err } -func (r *StateStore) Get(req *state.GetRequest) (*state.GetResponse, error) { +func (r *StateStore) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) { r.logger.Debugf("Get from OCI Object Storage State Store with key ", req.Key) - content, etag, err := r.readDocument((req)) + content, etag, err := r.readDocument(ctx, req) if err != nil { r.logger.Debugf("error %s", err) if err.Error() == "ObjectNotFound" { @@ -155,9 +155,9 @@ func (r *StateStore) Get(req *state.GetRequest) (*state.GetResponse, error) { }, err } -func (r *StateStore) Set(req *state.SetRequest) error { +func (r *StateStore) Set(ctx context.Context, req *state.SetRequest) error { r.logger.Debugf("saving %s to OCI Object Storage State Store", req.Key) - return r.writeDocument(req) + return r.writeDocument(ctx, req) } func (r *StateStore) Ping() error { @@ -267,7 +267,7 @@ func getIdentityAuthenticationDetails(metadata map[string]string, meta *Metadata } // functions that bridge from the Dapr State API to the OCI ObjectStorage Client. -func (r *StateStore) writeDocument(req *state.SetRequest) error { +func (r *StateStore) writeDocument(ctx context.Context, req *state.SetRequest) error { if len(req.Key) == 0 || req.Key == "" { return fmt.Errorf("key for value to set was missing from request") } @@ -286,7 +286,6 @@ func (r *StateStore) writeDocument(req *state.SetRequest) error { objectName := getFileName(req.Key) content := r.marshal(req) objectLength := int64(len(content)) - ctx := context.Background() etag := req.ETag if req.Options.Concurrency != state.FirstWrite { etag = nil @@ -315,12 +314,11 @@ func (r *StateStore) convertTTLtoExpiryTime(req *state.SetRequest, metadata map[ return nil } -func (r *StateStore) readDocument(req *state.GetRequest) ([]byte, *string, error) { +func (r *StateStore) readDocument(ctx context.Context, req *state.GetRequest) ([]byte, *string, error) { if len(req.Key) == 0 || req.Key == "" { return nil, nil, fmt.Errorf("key for value to get was missing from request") } objectName := getFileName(req.Key) - ctx := context.Background() content, etag, meta, err := r.client.getObject(ctx, objectName) if err != nil { r.logger.Debugf("download file %s, err %s", req.Key, err) @@ -348,13 +346,12 @@ func (r *StateStore) pingBucket() error { return nil } -func (r *StateStore) deleteDocument(req *state.DeleteRequest) error { +func (r *StateStore) deleteDocument(ctx context.Context, req *state.DeleteRequest) error { if len(req.Key) == 0 || req.Key == "" { return fmt.Errorf("key for value to delete was missing from request") } objectName := getFileName(req.Key) - ctx := context.Background() etag := req.ETag if req.Options.Concurrency != state.FirstWrite { etag = nil diff --git a/state/oci/objectstorage/objectstorage_integration_test.go b/state/oci/objectstorage/objectstorage_integration_test.go index d74e5028b..4b11f6d8e 100644 --- a/state/oci/objectstorage/objectstorage_integration_test.go +++ b/state/oci/objectstorage/objectstorage_integration_test.go @@ -4,6 +4,7 @@ package objectstorage // go test -v github.com/dapr/components-contrib/state/oci/objectstorage. import ( + "context" "fmt" "os" "testing" @@ -88,16 +89,16 @@ func testGet(t *testing.T, ociProperties map[string]string) { t.Run("Get an non-existing key", func(t *testing.T) { err := statestore.Init(meta) assert.Nil(t, err) - getResponse, err := statestore.Get(&state.GetRequest{Key: "xyzq"}) + getResponse, err := statestore.Get(context.TODO(), &state.GetRequest{Key: "xyzq"}) assert.Equal(t, &state.GetResponse{}, getResponse, "Response must be empty") assert.NoError(t, err, "Non-existing key must not be treated as error") }) t.Run("Get an existing key", func(t *testing.T) { err := statestore.Init(meta) assert.Nil(t, err) - err = statestore.Set(&state.SetRequest{Key: "test-key", Value: []byte("test-value")}) + err = statestore.Set(context.TODO(), &state.SetRequest{Key: "test-key", Value: []byte("test-value")}) assert.Nil(t, err) - getResponse, err := statestore.Get(&state.GetRequest{Key: "test-key"}) + getResponse, err := statestore.Get(context.TODO(), &state.GetRequest{Key: "test-key"}) assert.Nil(t, err) assert.Equal(t, "test-value", string(getResponse.Data), "Value retrieved should be equal to value set") assert.NotNil(t, *getResponse.ETag, "ETag should be set") @@ -105,9 +106,9 @@ func testGet(t *testing.T, ociProperties map[string]string) { t.Run("Get an existing composed key", func(t *testing.T) { err := statestore.Init(meta) assert.Nil(t, err) - err = statestore.Set(&state.SetRequest{Key: "test-app||test-key", Value: []byte("test-value")}) + err = statestore.Set(context.TODO(), &state.SetRequest{Key: "test-app||test-key", Value: []byte("test-value")}) assert.Nil(t, err) - getResponse, err := statestore.Get(&state.GetRequest{Key: "test-app||test-key"}) + getResponse, err := statestore.Get(context.TODO(), &state.GetRequest{Key: "test-app||test-key"}) assert.Nil(t, err) assert.Equal(t, "test-value", string(getResponse.Data), "Value retrieved should be equal to value set") }) @@ -115,11 +116,11 @@ func testGet(t *testing.T, ociProperties map[string]string) { testKey := "unexpired-ttl-test-key" err := statestore.Init(meta) assert.Nil(t, err) - err = statestore.Set(&state.SetRequest{Key: testKey, Value: []byte("test-value"), Metadata: (map[string]string{ + err = statestore.Set(context.TODO(), &state.SetRequest{Key: testKey, Value: []byte("test-value"), Metadata: (map[string]string{ "ttlInSeconds": "100", })}) assert.Nil(t, err) - getResponse, err := statestore.Get(&state.GetRequest{Key: testKey}) + getResponse, err := statestore.Get(context.TODO(), &state.GetRequest{Key: testKey}) assert.Nil(t, err) assert.Equal(t, "test-value", string(getResponse.Data), "Value retrieved should be equal to value set despite TTL setting") }) @@ -127,23 +128,23 @@ func testGet(t *testing.T, ociProperties map[string]string) { testKey := "never-expiring-ttl-test-key" err := statestore.Init(meta) assert.Nil(t, err) - err = statestore.Set(&state.SetRequest{Key: testKey, Value: []byte("test-value"), Metadata: (map[string]string{ + err = statestore.Set(context.TODO(), &state.SetRequest{Key: testKey, Value: []byte("test-value"), Metadata: (map[string]string{ "ttlInSeconds": "-1", })}) assert.Nil(t, err) - getResponse, err := statestore.Get(&state.GetRequest{Key: testKey}) + getResponse, err := statestore.Get(context.TODO(), &state.GetRequest{Key: testKey}) assert.Nil(t, err) assert.Equal(t, "test-value", string(getResponse.Data), "Value retrieved should be equal (TTL setting of -1 means never expire)") }) t.Run("Get an expired (TTL in the past) state element", func(t *testing.T) { err := statestore.Init(meta) assert.Nil(t, err) - err = statestore.Set(&state.SetRequest{Key: "ttl-test-key", Value: []byte("test-value"), Metadata: (map[string]string{ + err = statestore.Set(context.TODO(), &state.SetRequest{Key: "ttl-test-key", Value: []byte("test-value"), Metadata: (map[string]string{ "ttlInSeconds": "1", })}) assert.Nil(t, err) time.Sleep(time.Second * 2) - getResponse, err := statestore.Get(&state.GetRequest{Key: "ttl-test-key"}) + getResponse, err := statestore.Get(context.TODO(), &state.GetRequest{Key: "ttl-test-key"}) assert.Equal(t, &state.GetResponse{}, getResponse, "Response must be empty") assert.NoError(t, err, "Expired element must not be treated as error") }) @@ -156,7 +157,7 @@ func testSet(t *testing.T, ociProperties map[string]string) { t.Run("Set without a key", func(t *testing.T) { err := statestore.Init(meta) assert.Nil(t, err) - err = statestore.Set(&state.SetRequest{Value: []byte("test-value")}) + err = statestore.Set(context.TODO(), &state.SetRequest{Value: []byte("test-value")}) assert.Equal(t, err, fmt.Errorf("key for value to set was missing from request"), "Lacking Key results in error") }) t.Run("Regular Set Operation", func(t *testing.T) { @@ -164,9 +165,9 @@ func testSet(t *testing.T, ociProperties map[string]string) { err := statestore.Init(meta) assert.Nil(t, err) - err = statestore.Set(&state.SetRequest{Key: testKey, Value: []byte("test-value")}) + err = statestore.Set(context.TODO(), &state.SetRequest{Key: testKey, Value: []byte("test-value")}) assert.Nil(t, err, "Setting a value with a proper key should be errorfree") - getResponse, err := statestore.Get(&state.GetRequest{Key: testKey}) + getResponse, err := statestore.Get(context.TODO(), &state.GetRequest{Key: testKey}) assert.Nil(t, err) assert.Equal(t, "test-value", string(getResponse.Data), "Value retrieved should be equal to value set") assert.NotNil(t, *getResponse.ETag, "ETag should be set") @@ -176,20 +177,20 @@ func testSet(t *testing.T, ociProperties map[string]string) { err := statestore.Init(meta) assert.Nil(t, err) - err = statestore.Set(&state.SetRequest{Key: testKey, Value: []byte("test-value")}) + err = statestore.Set(context.TODO(), &state.SetRequest{Key: testKey, Value: []byte("test-value")}) assert.Nil(t, err, "Setting a value with a proper composite key should be errorfree") - getResponse, err := statestore.Get(&state.GetRequest{Key: testKey}) + getResponse, err := statestore.Get(context.TODO(), &state.GetRequest{Key: testKey}) assert.Nil(t, err) assert.Equal(t, "test-value", string(getResponse.Data), "Value retrieved should be equal to value set") assert.NotNil(t, *getResponse.ETag, "ETag should be set") }) t.Run("Regular Set Operation with TTL", func(t *testing.T) { testKey := "test-key-with-ttl" - err := statestore.Set(&state.SetRequest{Key: testKey, Value: []byte("test-value"), Metadata: (map[string]string{ + err := statestore.Set(context.TODO(), &state.SetRequest{Key: testKey, Value: []byte("test-value"), Metadata: (map[string]string{ "ttlInSeconds": "500", })}) assert.Nil(t, err, "Setting a value with a proper key and a correct TTL value should be errorfree") - err = statestore.Set(&state.SetRequest{Key: testKey, Value: []byte("test-value"), Metadata: (map[string]string{ + err = statestore.Set(context.TODO(), &state.SetRequest{Key: testKey, Value: []byte("test-value"), Metadata: (map[string]string{ "ttlInSeconds": "XXX", })}) assert.NotNil(t, err, "Setting a value with a proper key and a incorrect TTL value should be produce an error") @@ -200,25 +201,25 @@ func testSet(t *testing.T, ociProperties map[string]string) { err := statestore.Init(meta) assert.Nil(t, err) - err = statestore.Set(&state.SetRequest{Key: testKey, Value: []byte("test-value")}) + err = statestore.Set(context.TODO(), &state.SetRequest{Key: testKey, Value: []byte("test-value")}) assert.Nil(t, err, "Setting a value with a proper key should be errorfree") - getResponse, _ := statestore.Get(&state.GetRequest{Key: testKey}) + getResponse, _ := statestore.Get(context.TODO(), &state.GetRequest{Key: testKey}) etag := getResponse.ETag - err = statestore.Set(&state.SetRequest{Key: testKey, Value: []byte("overwritten-value"), ETag: etag, Options: state.SetStateOption{ + err = statestore.Set(context.TODO(), &state.SetRequest{Key: testKey, Value: []byte("overwritten-value"), ETag: etag, Options: state.SetStateOption{ Concurrency: state.FirstWrite, }}) assert.Nil(t, err, "Updating value with proper etag should go fine") - err = statestore.Set(&state.SetRequest{Key: testKey, Value: []byte("more-overwritten-value"), ETag: etag, Options: state.SetStateOption{ + err = statestore.Set(context.TODO(), &state.SetRequest{Key: testKey, Value: []byte("more-overwritten-value"), ETag: etag, Options: state.SetStateOption{ Concurrency: state.FirstWrite, }}) assert.NotNil(t, err, "Updating value with the old etag should be refused") // retrieve the latest etag - assigned by the previous set operation. - getResponse, _ = statestore.Get(&state.GetRequest{Key: testKey}) + getResponse, _ = statestore.Get(context.TODO(), &state.GetRequest{Key: testKey}) assert.NotNil(t, *getResponse.ETag, "ETag should be set") etag = getResponse.ETag - err = statestore.Set(&state.SetRequest{Key: testKey, Value: []byte("more-overwritten-value"), ETag: etag, Options: state.SetStateOption{ + err = statestore.Set(context.TODO(), &state.SetRequest{Key: testKey, Value: []byte("more-overwritten-value"), ETag: etag, Options: state.SetStateOption{ Concurrency: state.FirstWrite, }}) assert.Nil(t, err, "Updating value with the latest etag should be accepted") @@ -232,7 +233,7 @@ func testDelete(t *testing.T, ociProperties map[string]string) { t.Run("Delete without a key", func(t *testing.T) { err := s.Init(m) assert.Nil(t, err) - err = s.Delete(&state.DeleteRequest{}) + err = s.Delete(context.TODO(), &state.DeleteRequest{}) assert.Equal(t, err, fmt.Errorf("key for value to delete was missing from request"), "Lacking Key results in error") }) t.Run("Regular Delete Operation", func(t *testing.T) { @@ -240,9 +241,9 @@ func testDelete(t *testing.T, ociProperties map[string]string) { err := s.Init(m) assert.Nil(t, err) - err = s.Set(&state.SetRequest{Key: testKey, Value: []byte("test-value")}) + err = s.Set(context.TODO(), &state.SetRequest{Key: testKey, Value: []byte("test-value")}) assert.Nil(t, err, "Setting a value with a proper key should be errorfree") - err = s.Delete(&state.DeleteRequest{Key: testKey}) + err = s.Delete(context.TODO(), &state.DeleteRequest{Key: testKey}) assert.Nil(t, err, "Deleting an existing value with a proper key should be errorfree") }) t.Run("Regular Delete Operation for composite key", func(t *testing.T) { @@ -250,13 +251,13 @@ func testDelete(t *testing.T, ociProperties map[string]string) { err := s.Init(m) assert.Nil(t, err) - err = s.Set(&state.SetRequest{Key: testKey, Value: []byte("test-value")}) + err = s.Set(context.TODO(), &state.SetRequest{Key: testKey, Value: []byte("test-value")}) assert.Nil(t, err, "Setting a value with a proper composite key should be errorfree") - err = s.Delete(&state.DeleteRequest{Key: testKey}) + err = s.Delete(context.TODO(), &state.DeleteRequest{Key: testKey}) assert.Nil(t, err, "Deleting an existing value with a proper composite key should be errorfree") }) t.Run("Delete with an unknown key", func(t *testing.T) { - err := s.Delete(&state.DeleteRequest{Key: "unknownKey"}) + err := s.Delete(context.TODO(), &state.DeleteRequest{Key: "unknownKey"}) assert.Contains(t, err.Error(), "404", "Unknown Key results in error: http status code 404, object not found") }) @@ -265,18 +266,18 @@ func testDelete(t *testing.T, ociProperties map[string]string) { err := s.Init(m) assert.Nil(t, err) // create document. - err = s.Set(&state.SetRequest{Key: testKey, Value: []byte("test-value")}) + err = s.Set(context.TODO(), &state.SetRequest{Key: testKey, Value: []byte("test-value")}) assert.Nil(t, err, "Setting a value with a proper key should be errorfree") - getResponse, _ := s.Get(&state.GetRequest{Key: testKey}) + getResponse, _ := s.Get(context.TODO(), &state.GetRequest{Key: testKey}) etag := getResponse.ETag incorrectETag := "someRandomETag" - err = s.Delete(&state.DeleteRequest{Key: testKey, ETag: &incorrectETag, Options: state.DeleteStateOption{ + err = s.Delete(context.TODO(), &state.DeleteRequest{Key: testKey, ETag: &incorrectETag, Options: state.DeleteStateOption{ Concurrency: state.FirstWrite, }}) assert.NotNil(t, err, "Deleting value with an incorrect etag should be prevented") - err = s.Delete(&state.DeleteRequest{Key: testKey, ETag: etag, Options: state.DeleteStateOption{ + err = s.Delete(context.TODO(), &state.DeleteRequest{Key: testKey, ETag: etag, Options: state.DeleteStateOption{ Concurrency: state.FirstWrite, }}) assert.Nil(t, err, "Deleting value with proper etag should go fine") diff --git a/state/oci/objectstorage/objectstorage_test.go b/state/oci/objectstorage/objectstorage_test.go index b2d4b1a46..146710ea1 100644 --- a/state/oci/objectstorage/objectstorage_test.go +++ b/state/oci/objectstorage/objectstorage_test.go @@ -236,24 +236,24 @@ func TestGetWithMockClient(t *testing.T) { s.client = mockClient t.Parallel() t.Run("Test regular Get", func(t *testing.T) { - getResponse, err := s.Get(&state.GetRequest{Key: "test-key"}) + getResponse, err := s.Get(context.TODO(), &state.GetRequest{Key: "test-key"}) assert.True(t, mockClient.getIsCalled, "function Get should be invoked on the mockClient") assert.Equal(t, "Hello World", string(getResponse.Data), "Value retrieved should be equal to value set") assert.NotNil(t, *getResponse.ETag, "ETag should be set") assert.Nil(t, err) }) t.Run("Test Get with composite key", func(t *testing.T) { - getResponse, err := s.Get(&state.GetRequest{Key: "test-app||test-key"}) + getResponse, err := s.Get(context.TODO(), &state.GetRequest{Key: "test-app||test-key"}) assert.Equal(t, "Hello Continent", string(getResponse.Data), "Value retrieved should be equal to value set") assert.Nil(t, err) }) t.Run("Test Get with an unknown key", func(t *testing.T) { - getResponse, err := s.Get(&state.GetRequest{Key: "unknownKey"}) + getResponse, err := s.Get(context.TODO(), &state.GetRequest{Key: "unknownKey"}) assert.Nil(t, getResponse.Data, "No value should be retrieved for an unknown key") assert.Nil(t, err, "404", "Not finding an object because of unknown key should not result in an error") }) t.Run("Test expired element (because of TTL) ", func(t *testing.T) { - getResponse, err := s.Get(&state.GetRequest{Key: "test-expired-ttl-key"}) + getResponse, err := s.Get(context.TODO(), &state.GetRequest{Key: "test-expired-ttl-key"}) assert.Nil(t, getResponse.Data, "No value should be retrieved for an expired state element") assert.Nil(t, err, "Not returning an object because of expiration should not result in an error") }) @@ -289,28 +289,28 @@ func TestSetWithMockClient(t *testing.T) { mockClient := &mockedObjectStoreClient{} statestore.client = mockClient t.Run("Set without a key", func(t *testing.T) { - err := statestore.Set(&state.SetRequest{Value: []byte("test-value")}) + err := statestore.Set(context.TODO(), &state.SetRequest{Value: []byte("test-value")}) assert.Equal(t, err, fmt.Errorf("key for value to set was missing from request"), "Lacking Key results in error") }) t.Run("Regular Set Operation", func(t *testing.T) { testKey := "test-key" - err := statestore.Set(&state.SetRequest{Key: testKey, Value: []byte("test-value")}) + err := statestore.Set(context.TODO(), &state.SetRequest{Key: testKey, Value: []byte("test-value")}) assert.Nil(t, err, "Setting a value with a proper key should be errorfree") assert.True(t, mockClient.putIsCalled, "function put should be invoked on the mockClient") }) t.Run("Regular Set Operation with TTL", func(t *testing.T) { testKey := "test-key" - err := statestore.Set(&state.SetRequest{Key: testKey, Value: []byte("test-value"), Metadata: (map[string]string{ + err := statestore.Set(context.TODO(), &state.SetRequest{Key: testKey, Value: []byte("test-value"), Metadata: (map[string]string{ "ttlInSeconds": "5", })}) assert.Nil(t, err, "Setting a value with a proper key and a correct TTL value should be errorfree") - err = statestore.Set(&state.SetRequest{Key: testKey, Value: []byte("test-value"), Metadata: (map[string]string{ + err = statestore.Set(context.TODO(), &state.SetRequest{Key: testKey, Value: []byte("test-value"), Metadata: (map[string]string{ "ttlInSeconds": "XXX", })}) assert.NotNil(t, err, "Setting a value with a proper key and a incorrect TTL value should be produce an error") - err = statestore.Set(&state.SetRequest{Key: testKey, Value: []byte("test-value"), Metadata: (map[string]string{ + err = statestore.Set(context.TODO(), &state.SetRequest{Key: testKey, Value: []byte("test-value"), Metadata: (map[string]string{ "ttlInSeconds": "1", })}) assert.Nil(t, err, "Setting a value with a proper key and a correct TTL value should be errorfree") @@ -320,22 +320,22 @@ func TestSetWithMockClient(t *testing.T) { incorrectETag := "notTheCorrectETag" etag := "correctETag" - err := statestore.Set(&state.SetRequest{Key: testKey, Value: []byte("overwritten-value"), ETag: &incorrectETag, Options: state.SetStateOption{ + err := statestore.Set(context.TODO(), &state.SetRequest{Key: testKey, Value: []byte("overwritten-value"), ETag: &incorrectETag, Options: state.SetStateOption{ Concurrency: state.FirstWrite, }}) assert.NotNil(t, err, "Updating value with wrong etag should fail") - err = statestore.Set(&state.SetRequest{Key: testKey, Value: []byte("overwritten-value"), ETag: nil, Options: state.SetStateOption{ + err = statestore.Set(context.TODO(), &state.SetRequest{Key: testKey, Value: []byte("overwritten-value"), ETag: nil, Options: state.SetStateOption{ Concurrency: state.FirstWrite, }}) assert.NotNil(t, err, "Asking for FirstWrite concurrency policy without ETag should fail") - err = statestore.Set(&state.SetRequest{Key: testKey, Value: []byte("overwritten-value"), ETag: &etag, Options: state.SetStateOption{ + err = statestore.Set(context.TODO(), &state.SetRequest{Key: testKey, Value: []byte("overwritten-value"), ETag: &etag, Options: state.SetStateOption{ Concurrency: state.FirstWrite, }}) assert.Nil(t, err, "Updating value with proper etag should go fine") - err = statestore.Set(&state.SetRequest{Key: testKey, Value: []byte("overwritten-value"), ETag: nil, Options: state.SetStateOption{ + err = statestore.Set(context.TODO(), &state.SetRequest{Key: testKey, Value: []byte("overwritten-value"), ETag: nil, Options: state.SetStateOption{ Concurrency: state.FirstWrite, }}) assert.NotNil(t, err, "Updating value with concurrency policy at FirstWrite should fail when ETag is missing") @@ -348,34 +348,34 @@ func TestDeleteWithMockClient(t *testing.T) { mockClient := &mockedObjectStoreClient{} s.client = mockClient t.Run("Delete without a key", func(t *testing.T) { - err := s.Delete(&state.DeleteRequest{}) + err := s.Delete(context.TODO(), &state.DeleteRequest{}) assert.Equal(t, err, fmt.Errorf("key for value to delete was missing from request"), "Lacking Key results in error") }) t.Run("Delete with an unknown key", func(t *testing.T) { - err := s.Delete(&state.DeleteRequest{Key: "unknownKey"}) + err := s.Delete(context.TODO(), &state.DeleteRequest{Key: "unknownKey"}) assert.Contains(t, err.Error(), "404", "Unknown Key results in error: http status code 404, object not found") }) t.Run("Regular Delete Operation", func(t *testing.T) { testKey := "test-key" - err := s.Delete(&state.DeleteRequest{Key: testKey}) + err := s.Delete(context.TODO(), &state.DeleteRequest{Key: testKey}) assert.Nil(t, err, "Deleting an existing value with a proper key should be errorfree") assert.True(t, mockClient.deleteIsCalled, "function delete should be invoked on the mockClient") }) t.Run("Testing Delete & Concurrency (ETags)", func(t *testing.T) { testKey := "etag-test-delete-key" incorrectETag := "notTheCorrectETag" - err := s.Delete(&state.DeleteRequest{Key: testKey, ETag: &incorrectETag, Options: state.DeleteStateOption{ + err := s.Delete(context.TODO(), &state.DeleteRequest{Key: testKey, ETag: &incorrectETag, Options: state.DeleteStateOption{ Concurrency: state.FirstWrite, }}) assert.NotNil(t, err, "Deleting value with an incorrect etag should be prevented") etag := "correctETag" - err = s.Delete(&state.DeleteRequest{Key: testKey, ETag: &etag, Options: state.DeleteStateOption{ + err = s.Delete(context.TODO(), &state.DeleteRequest{Key: testKey, ETag: &etag, Options: state.DeleteStateOption{ Concurrency: state.FirstWrite, }}) assert.Nil(t, err, "Deleting value with proper etag should go fine") - err = s.Delete(&state.DeleteRequest{Key: testKey, ETag: nil, Options: state.DeleteStateOption{ + err = s.Delete(context.TODO(), &state.DeleteRequest{Key: testKey, ETag: nil, Options: state.DeleteStateOption{ Concurrency: state.FirstWrite, }}) assert.NotNil(t, err, "Asking for FirstWrite concurrency policy without ETag should fail") diff --git a/state/oracledatabase/dbaccess.go b/state/oracledatabase/dbaccess.go index 60e7ee336..188cf8366 100644 --- a/state/oracledatabase/dbaccess.go +++ b/state/oracledatabase/dbaccess.go @@ -14,6 +14,8 @@ limitations under the License. package oracledatabase import ( + "context" + "github.com/dapr/components-contrib/state" ) @@ -21,9 +23,9 @@ import ( type dbAccess interface { Init(metadata state.Metadata) error Ping() error - Set(req *state.SetRequest) error - Get(req *state.GetRequest) (*state.GetResponse, error) - Delete(req *state.DeleteRequest) error - ExecuteMulti(sets []state.SetRequest, deletes []state.DeleteRequest) error + Set(ctx context.Context, req *state.SetRequest) error + Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) + Delete(ctx context.Context, req *state.DeleteRequest) error + ExecuteMulti(ctx context.Context, sets []state.SetRequest, deletes []state.DeleteRequest) error Close() error // io.Closer. } diff --git a/state/oracledatabase/oracledatabase.go b/state/oracledatabase/oracledatabase.go index f607947ef..46d08f3d7 100644 --- a/state/oracledatabase/oracledatabase.go +++ b/state/oracledatabase/oracledatabase.go @@ -14,6 +14,7 @@ limitations under the License. package oracledatabase import ( + "context" "fmt" "github.com/dapr/components-contrib/state" @@ -59,38 +60,38 @@ func (o *OracleDatabase) Features() []state.Feature { } // Delete removes an entity from the store. -func (o *OracleDatabase) Delete(req *state.DeleteRequest) error { - return o.dbaccess.Delete(req) +func (o *OracleDatabase) Delete(ctx context.Context, req *state.DeleteRequest) error { + return o.dbaccess.Delete(ctx, req) } // BulkDelete removes multiple entries from the store. -func (o *OracleDatabase) BulkDelete(req []state.DeleteRequest) error { - return o.dbaccess.ExecuteMulti(nil, req) +func (o *OracleDatabase) BulkDelete(ctx context.Context, req []state.DeleteRequest) error { + return o.dbaccess.ExecuteMulti(ctx, nil, req) } // Get returns an entity from store. -func (o *OracleDatabase) Get(req *state.GetRequest) (*state.GetResponse, error) { - return o.dbaccess.Get(req) +func (o *OracleDatabase) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) { + return o.dbaccess.Get(ctx, req) } // BulkGet performs a bulks get operations. -func (o *OracleDatabase) BulkGet(req []state.GetRequest) (bool, []state.BulkGetResponse, error) { +func (o *OracleDatabase) BulkGet(ctx context.Context, req []state.GetRequest) (bool, []state.BulkGetResponse, error) { // TODO: replace with ExecuteMulti for performance. return false, nil, nil } // Set adds/updates an entity on store. -func (o *OracleDatabase) Set(req *state.SetRequest) error { - return o.dbaccess.Set(req) +func (o *OracleDatabase) Set(ctx context.Context, req *state.SetRequest) error { + return o.dbaccess.Set(ctx, req) } // BulkSet adds/updates multiple entities on store. -func (o *OracleDatabase) BulkSet(req []state.SetRequest) error { - return o.dbaccess.ExecuteMulti(req, nil) +func (o *OracleDatabase) BulkSet(ctx context.Context, req []state.SetRequest) error { + return o.dbaccess.ExecuteMulti(ctx, req, nil) } // Multi handles multiple transactions. Implements TransactionalStore. -func (o *OracleDatabase) Multi(request *state.TransactionalStateRequest) error { +func (o *OracleDatabase) Multi(ctx context.Context, request *state.TransactionalStateRequest) error { var deletes []state.DeleteRequest var sets []state.SetRequest for _, req := range request.Operations { @@ -115,7 +116,7 @@ func (o *OracleDatabase) Multi(request *state.TransactionalStateRequest) error { } if len(sets) > 0 || len(deletes) > 0 { - return o.dbaccess.ExecuteMulti(sets, deletes) + return o.dbaccess.ExecuteMulti(ctx, sets, deletes) } return nil diff --git a/state/oracledatabase/oracledatabase_integration_test.go b/state/oracledatabase/oracledatabase_integration_test.go index d176e6252..6d8e1caaf 100644 --- a/state/oracledatabase/oracledatabase_integration_test.go +++ b/state/oracledatabase/oracledatabase_integration_test.go @@ -15,6 +15,7 @@ limitations under the License. package oracledatabase import ( + "context" "database/sql" "encoding/json" "fmt" @@ -223,7 +224,7 @@ func deleteItemThatDoesNotExist(t *testing.T, ods *OracleDatabase) { deleteReq := &state.DeleteRequest{ Key: randomKey(), } - err := ods.Delete(deleteReq) + err := ods.Delete(context.TODO(), deleteReq) assert.Nil(t, err) } @@ -242,7 +243,7 @@ func multiWithSetOnly(t *testing.T, ods *OracleDatabase) { }) } - err := ods.Multi(&state.TransactionalStateRequest{ + err := ods.Multi(context.TODO(), &state.TransactionalStateRequest{ Operations: operations, }) assert.Nil(t, err) @@ -272,7 +273,7 @@ func multiWithDeleteOnly(t *testing.T, ods *OracleDatabase) { }) } - err := ods.Multi(&state.TransactionalStateRequest{ + err := ods.Multi(context.TODO(), &state.TransactionalStateRequest{ Operations: operations, }) assert.Nil(t, err) @@ -315,7 +316,7 @@ func multiWithDeleteAndSet(t *testing.T, ods *OracleDatabase) { }) } - err := ods.Multi(&state.TransactionalStateRequest{ + err := ods.Multi(context.TODO(), &state.TransactionalStateRequest{ Operations: operations, }) assert.Nil(t, err) @@ -345,7 +346,7 @@ func deleteWithInvalidEtagFails(t *testing.T, ods *OracleDatabase) { Concurrency: state.FirstWrite, }, } - err := ods.Delete(deleteReq) + err := ods.Delete(context.TODO(), deleteReq) assert.NotNil(t, err, "Deleting an item with the wrong etag while enforcing FirstWrite policy should fail") } @@ -353,7 +354,7 @@ func deleteWithNoKeyFails(t *testing.T, ods *OracleDatabase) { deleteReq := &state.DeleteRequest{ Key: "", } - err := ods.Delete(deleteReq) + err := ods.Delete(context.TODO(), deleteReq) assert.NotNil(t, err) } @@ -371,7 +372,7 @@ func newItemWithEtagFails(t *testing.T, ods *OracleDatabase) { }, } - err := ods.Set(setReq) + err := ods.Set(context.TODO(), setReq) assert.NotNil(t, err) } @@ -401,7 +402,7 @@ func updateWithOldEtagFails(t *testing.T, ods *OracleDatabase) { Concurrency: state.FirstWrite, }, } - err := ods.Set(setReq) + err := ods.Set(context.TODO(), setReq) assert.NotNil(t, err) } @@ -423,7 +424,7 @@ func updateAndDeleteWithEtagSucceeds(t *testing.T, ods *OracleDatabase) { Concurrency: state.FirstWrite, }, } - err := ods.Set(setReq) + err := ods.Set(context.TODO(), setReq) assert.Nil(t, err, "Setting the item should be successful") updateResponse, updatedItem := getItem(t, ods, key) assert.Equal(t, value, updatedItem) @@ -439,7 +440,7 @@ func updateAndDeleteWithEtagSucceeds(t *testing.T, ods *OracleDatabase) { Concurrency: state.FirstWrite, }, } - err = ods.Delete(deleteReq) + err = ods.Delete(context.TODO(), deleteReq) assert.Nil(t, err, "Deleting an item with the right etag while enforcing FirstWrite policy should succeed") // Item is not in the data store. @@ -465,7 +466,7 @@ func updateAndDeleteWithWrongEtagAndNoFirstWriteSucceeds(t *testing.T, ods *Orac Concurrency: state.LastWrite, }, } - err := ods.Set(setReq) + err := ods.Set(context.TODO(), setReq) assert.Nil(t, err, "Setting the item should be successful") _, updatedItem := getItem(t, ods, key) assert.Equal(t, value, updatedItem) @@ -478,7 +479,7 @@ func updateAndDeleteWithWrongEtagAndNoFirstWriteSucceeds(t *testing.T, ods *Orac Concurrency: state.LastWrite, }, } - err = ods.Delete(deleteReq) + err = ods.Delete(context.TODO(), deleteReq) assert.Nil(t, err, "Deleting an item with the wrong etag but not enforcing FirstWrite policy should succeed") // Item is not in the data store. @@ -500,7 +501,7 @@ func getItemWithNoKey(t *testing.T, ods *OracleDatabase) { Key: "", } - response, getErr := ods.Get(getReq) + response, getErr := ods.Get(context.TODO(), getReq) assert.NotNil(t, getErr) assert.Nil(t, response) } @@ -548,7 +549,7 @@ func setTTLUpdatesExpiry(t *testing.T, ods *OracleDatabase) { }, } - err := ods.Set(setReq) + err := ods.Set(context.TODO(), setReq) assert.Nil(t, err) connectionString := getConnectionString() if getWalletLocation() != "" { @@ -580,10 +581,10 @@ func setNoTTLUpdatesExpiry(t *testing.T, ods *OracleDatabase) { "ttlInSeconds": "1000", }, } - err := ods.Set(setReq) + err := ods.Set(context.TODO(), setReq) assert.Nil(t, err) delete(setReq.Metadata, "ttlInSeconds") - err = ods.Set(setReq) + err = ods.Set(context.TODO(), setReq) assert.Nil(t, err) connectionString := getConnectionString() if getWalletLocation() != "" { @@ -614,11 +615,11 @@ func expiredStateCannotBeRead(t *testing.T, ods *OracleDatabase) { "ttlInSeconds": "1", }, } - err := ods.Set(setReq) + err := ods.Set(context.TODO(), setReq) assert.Nil(t, err) time.Sleep(time.Second * time.Duration(2)) - getResponse, err := ods.Get(&state.GetRequest{Key: key}) + getResponse, err := ods.Get(context.TODO(), &state.GetRequest{Key: key}) assert.Equal(t, &state.GetResponse{}, getResponse, "Response must be empty") assert.NoError(t, err, "Expired element must not be treated as error") @@ -639,7 +640,7 @@ func unexpiredStateCanBeRead(t *testing.T, ods *OracleDatabase) { "ttlInSeconds": "10000", }, } - err := ods.Set(setReq) + err := ods.Set(context.TODO(), setReq) assert.Nil(t, err) _, getValue := getItem(t, ods, key) assert.Equal(t, value.Color, getValue.Color, "Response must be as set") @@ -653,7 +654,7 @@ func setItemWithNoKey(t *testing.T, ods *OracleDatabase) { Key: "", } - err := ods.Set(setReq) + err := ods.Set(context.TODO(), setReq) assert.NotNil(t, err) } @@ -702,7 +703,7 @@ func testSetItemWithInvalidTTL(t *testing.T, ods *OracleDatabase) { "ttlInSeconds": "XX", }), } - err := ods.Set(setReq) + err := ods.Set(context.TODO(), setReq) assert.NotNil(t, err, "Setting a value with a proper key and a incorrect TTL value should be produce an error") } @@ -714,7 +715,7 @@ func testSetItemWithNegativeTTL(t *testing.T, ods *OracleDatabase) { "ttlInSeconds": "-10", }), } - err := ods.Set(setReq) + err := ods.Set(context.TODO(), setReq) assert.NotNil(t, err, "Setting a value with a proper key and a negative (other than -1) TTL value should be produce an error") } @@ -731,7 +732,7 @@ func testBulkSetAndBulkDelete(t *testing.T, ods *OracleDatabase) { }, } - err := ods.BulkSet(setReq) + err := ods.BulkSet(context.TODO(), setReq) assert.Nil(t, err) assert.True(t, storeItemExists(t, setReq[0].Key)) assert.True(t, storeItemExists(t, setReq[1].Key)) @@ -745,7 +746,7 @@ func testBulkSetAndBulkDelete(t *testing.T, ods *OracleDatabase) { }, } - err = ods.BulkDelete(deleteReq) + err = ods.BulkDelete(context.TODO(), deleteReq) assert.Nil(t, err) assert.False(t, storeItemExists(t, setReq[0].Key)) assert.False(t, storeItemExists(t, setReq[1].Key)) @@ -812,7 +813,7 @@ func setItem(t *testing.T, ods *OracleDatabase, key string, value interface{}, e Options: setOptions, } - err := ods.Set(setReq) + err := ods.Set(context.TODO(), setReq) assert.Nil(t, err) itemExists := storeItemExists(t, key) assert.True(t, itemExists, "Item should exist after set has been executed ") @@ -824,7 +825,7 @@ func getItem(t *testing.T, ods *OracleDatabase, key string) (*state.GetResponse, Options: state.GetStateOption{}, } - response, getErr := ods.Get(getReq) + response, getErr := ods.Get(context.TODO(), getReq) assert.Nil(t, getErr) assert.NotNil(t, response) outputObject := &fakeItem{} @@ -840,7 +841,7 @@ func deleteItem(t *testing.T, ods *OracleDatabase, key string, etag *string) { Options: state.DeleteStateOption{}, } - deleteErr := ods.Delete(deleteReq) + deleteErr := ods.Delete(context.TODO(), deleteReq) assert.Nil(t, deleteErr) assert.False(t, storeItemExists(t, key), "item should no longer exist after delete has been performed") } diff --git a/state/oracledatabase/oracledatabase_test.go b/state/oracledatabase/oracledatabase_test.go index 628821341..c784a6255 100644 --- a/state/oracledatabase/oracledatabase_test.go +++ b/state/oracledatabase/oracledatabase_test.go @@ -15,6 +15,7 @@ limitations under the License. package oracledatabase import ( + "context" "testing" "github.com/stretchr/testify/assert" @@ -48,23 +49,23 @@ func (m *fakeDBaccess) Init(metadata state.Metadata) error { return nil } -func (m *fakeDBaccess) Set(req *state.SetRequest) error { +func (m *fakeDBaccess) Set(ctx context.Context, req *state.SetRequest) error { m.setExecuted = true return nil } -func (m *fakeDBaccess) Get(req *state.GetRequest) (*state.GetResponse, error) { +func (m *fakeDBaccess) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) { m.getExecuted = true return nil, nil } -func (m *fakeDBaccess) Delete(req *state.DeleteRequest) error { +func (m *fakeDBaccess) Delete(ctx context.Context, req *state.DeleteRequest) error { return nil } -func (m *fakeDBaccess) ExecuteMulti(sets []state.SetRequest, deletes []state.DeleteRequest) error { +func (m *fakeDBaccess) ExecuteMulti(ctx context.Context, sets []state.SetRequest, deletes []state.DeleteRequest) error { return nil } @@ -84,7 +85,7 @@ func TestMultiWithNoRequestsReturnsNil(t *testing.T) { t.Parallel() var operations []state.TransactionalStateOperation ods := createOracleDatabase(t) - err := ods.Multi(&state.TransactionalStateRequest{ + err := ods.Multi(context.TODO(), &state.TransactionalStateRequest{ Operations: operations, }) assert.Nil(t, err) @@ -100,7 +101,7 @@ func TestInvalidMultiAction(t *testing.T) { }) ods := createOracleDatabase(t) - err := ods.Multi(&state.TransactionalStateRequest{ + err := ods.Multi(context.TODO(), &state.TransactionalStateRequest{ Operations: operations, }) assert.NotNil(t, err) @@ -116,7 +117,7 @@ func TestValidSetRequest(t *testing.T) { }) ods := createOracleDatabase(t) - err := ods.Multi(&state.TransactionalStateRequest{ + err := ods.Multi(context.TODO(), &state.TransactionalStateRequest{ Operations: operations, }) assert.Nil(t, err) @@ -132,7 +133,7 @@ func TestInvalidMultiSetRequest(t *testing.T) { }) ods := createOracleDatabase(t) - err := ods.Multi(&state.TransactionalStateRequest{ + err := ods.Multi(context.TODO(), &state.TransactionalStateRequest{ Operations: operations, }) assert.NotNil(t, err) @@ -148,7 +149,7 @@ func TestValidMultiDeleteRequest(t *testing.T) { }) ods := createOracleDatabase(t) - err := ods.Multi(&state.TransactionalStateRequest{ + err := ods.Multi(context.TODO(), &state.TransactionalStateRequest{ Operations: operations, }) assert.Nil(t, err) @@ -164,7 +165,7 @@ func TestInvalidMultiDeleteRequest(t *testing.T) { }) ods := createOracleDatabase(t) - err := ods.Multi(&state.TransactionalStateRequest{ + err := ods.Multi(context.TODO(), &state.TransactionalStateRequest{ Operations: operations, }) assert.NotNil(t, err) diff --git a/state/oracledatabase/oracledatabaseaccess.go b/state/oracledatabase/oracledatabaseaccess.go index 5fba997c3..133bde306 100644 --- a/state/oracledatabase/oracledatabaseaccess.go +++ b/state/oracledatabase/oracledatabaseaccess.go @@ -14,6 +14,7 @@ limitations under the License. package oracledatabase import ( + "context" "database/sql" "encoding/base64" "encoding/json" @@ -96,8 +97,8 @@ func (o *oracleDatabaseAccess) Init(metadata state.Metadata) error { } // Set makes an insert or update to the database. -func (o *oracleDatabaseAccess) Set(req *state.SetRequest) error { - return state.SetWithOptions(o.setValue, req) +func (o *oracleDatabaseAccess) Set(ctx context.Context, req *state.SetRequest) error { + return state.SetWithOptions(ctx, o.setValue, req) } func parseTTL(requestMetadata map[string]string) (*int, error) { @@ -115,7 +116,7 @@ func parseTTL(requestMetadata map[string]string) (*int, error) { } // setValue is an internal implementation of set to enable passing the logic to state.SetWithRetries as a func. -func (o *oracleDatabaseAccess) setValue(req *state.SetRequest) error { +func (o *oracleDatabaseAccess) setValue(ctx context.Context, req *state.SetRequest) error { o.logger.Debug("Setting state value in OracleDatabase") err := state.CheckRequestOptions(req.Options) if err != nil { @@ -182,14 +183,14 @@ func (o *oracleDatabaseAccess) setValue(req *state.SetRequest) error { WHEN MATCHED THEN UPDATE SET value = new_state_to_store.value, binary_yn = new_state_to_store.binary_yn, update_time = systimestamp, etag = new_state_to_store.etag, t.expiration_time = case when new_state_to_store.ttl_in_seconds >0 then systimestamp + numtodsinterval(new_state_to_store.ttl_in_seconds, 'SECOND') end WHEN NOT MATCHED THEN INSERT (t.key, t.value, t.binary_yn, t.etag, t.expiration_time) values (new_state_to_store.key, new_state_to_store.value, new_state_to_store.binary_yn, new_state_to_store.etag, case when new_state_to_store.ttl_in_seconds >0 then systimestamp + numtodsinterval(new_state_to_store.ttl_in_seconds, 'SECOND') end ) `, tableName) - result, err = tx.Exec(mergeStatement, req.Key, value, binaryYN, etag, ttlSeconds) + result, err = tx.ExecContext(ctx, mergeStatement, req.Key, value, binaryYN, etag, ttlSeconds) } else { // when first write policy is indicated, an existing record has to be updated - one that has the etag provided. updateStatement := fmt.Sprintf( `UPDATE %s SET value = :value, binary_yn = :binary_yn, etag = :new_etag WHERE key = :key AND etag = :etag`, tableName) - result, err = tx.Exec(updateStatement, value, binaryYN, etag, req.Key, *req.ETag) + result, err = tx.ExecContext(ctx, updateStatement, value, binaryYN, etag, req.Key, *req.ETag) } if err != nil { if req.ETag != nil && *req.ETag != "" { @@ -214,7 +215,7 @@ func (o *oracleDatabaseAccess) setValue(req *state.SetRequest) error { } // Get returns data from the database. If data does not exist for the key an empty state.GetResponse will be returned. -func (o *oracleDatabaseAccess) Get(req *state.GetRequest) (*state.GetResponse, error) { +func (o *oracleDatabaseAccess) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) { o.logger.Debug("Getting state value from OracleDatabase") if req.Key == "" { return nil, fmt.Errorf("missing key in get operation") @@ -222,7 +223,7 @@ func (o *oracleDatabaseAccess) Get(req *state.GetRequest) (*state.GetResponse, e var value string var binaryYN string var etag string - err := o.db.QueryRow(fmt.Sprintf("SELECT value, binary_yn, etag FROM %s WHERE key = :key and (expiration_time is null or expiration_time > systimestamp)", tableName), req.Key).Scan(&value, &binaryYN, &etag) + err := o.db.QueryRowContext(ctx, fmt.Sprintf("SELECT value, binary_yn, etag FROM %s WHERE key = :key and (expiration_time is null or expiration_time > systimestamp)", tableName), req.Key).Scan(&value, &binaryYN, &etag) if err != nil { // If no rows exist, return an empty response, otherwise return the error. if err == sql.ErrNoRows { @@ -253,12 +254,12 @@ func (o *oracleDatabaseAccess) Get(req *state.GetRequest) (*state.GetResponse, e } // Delete removes an item from the state store. -func (o *oracleDatabaseAccess) Delete(req *state.DeleteRequest) error { - return state.DeleteWithOptions(o.deleteValue, req) +func (o *oracleDatabaseAccess) Delete(ctx context.Context, req *state.DeleteRequest) error { + return state.DeleteWithOptions(ctx, o.deleteValue, req) } // deleteValue is an internal implementation of delete to enable passing the logic to state.DeleteWithRetries as a func. -func (o *oracleDatabaseAccess) deleteValue(req *state.DeleteRequest) error { +func (o *oracleDatabaseAccess) deleteValue(ctx context.Context, req *state.DeleteRequest) error { o.logger.Debug("Deleting state value from OracleDatabase") if req.Key == "" { return fmt.Errorf("missing key in delete operation") @@ -280,9 +281,9 @@ func (o *oracleDatabaseAccess) deleteValue(req *state.DeleteRequest) error { } // QUESTION: only check for etag if FirstWrite specified - or always when etag is supplied?? if req.Options.Concurrency != state.FirstWrite { - result, err = tx.Exec("DELETE FROM state WHERE key = :key", req.Key) + result, err = tx.ExecContext(ctx, "DELETE FROM state WHERE key = :key", req.Key) } else { - result, err = tx.Exec("DELETE FROM state WHERE key = :key and etag = :etag", req.Key, *req.ETag) + result, err = tx.ExecContext(ctx, "DELETE FROM state WHERE key = :key and etag = :etag", req.Key, *req.ETag) } if err != nil { if o.tx == nil { // not joining a preexisting transaction. @@ -303,7 +304,7 @@ func (o *oracleDatabaseAccess) deleteValue(req *state.DeleteRequest) error { return nil } -func (o *oracleDatabaseAccess) ExecuteMulti(sets []state.SetRequest, deletes []state.DeleteRequest) error { +func (o *oracleDatabaseAccess) ExecuteMulti(ctx context.Context, sets []state.SetRequest, deletes []state.DeleteRequest) error { o.logger.Debug("Executing multiple OracleDatabase operations, within a single transaction") tx, err := o.db.Begin() if err != nil { @@ -313,7 +314,7 @@ func (o *oracleDatabaseAccess) ExecuteMulti(sets []state.SetRequest, deletes []s if len(deletes) > 0 { for _, d := range deletes { da := d // Fix for gosec G601: Implicit memory aliasing in for looo. - err = o.Delete(&da) + err = o.Delete(ctx, &da) if err != nil { tx.Rollback() return err @@ -323,7 +324,7 @@ func (o *oracleDatabaseAccess) ExecuteMulti(sets []state.SetRequest, deletes []s if len(sets) > 0 { for _, s := range sets { sa := s // Fix for gosec G601: Implicit memory aliasing in for looo. - err = o.Set(&sa) + err = o.Set(ctx, &sa) if err != nil { tx.Rollback() return err diff --git a/state/postgresql/dbaccess.go b/state/postgresql/dbaccess.go index d4575be3f..eb1f49a07 100644 --- a/state/postgresql/dbaccess.go +++ b/state/postgresql/dbaccess.go @@ -14,18 +14,20 @@ limitations under the License. package postgresql import ( + "context" + "github.com/dapr/components-contrib/state" ) // dbAccess is a private interface which enables unit testing of PostgreSQL. type dbAccess interface { Init(metadata state.Metadata) error - Set(req *state.SetRequest) error - BulkSet(req []state.SetRequest) error - Get(req *state.GetRequest) (*state.GetResponse, error) - Delete(req *state.DeleteRequest) error - BulkDelete(req []state.DeleteRequest) error - ExecuteMulti(req *state.TransactionalStateRequest) error - Query(req *state.QueryRequest) (*state.QueryResponse, error) + Set(ctx context.Context, req *state.SetRequest) error + BulkSet(ctx context.Context, req []state.SetRequest) error + Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) + Delete(ctx context.Context, req *state.DeleteRequest) error + BulkDelete(ctx context.Context, req []state.DeleteRequest) error + ExecuteMulti(ctx context.Context, req *state.TransactionalStateRequest) error + Query(ctx context.Context, req *state.QueryRequest) (*state.QueryResponse, error) Close() error // io.Closer } diff --git a/state/postgresql/postgresdbaccess.go b/state/postgresql/postgresdbaccess.go index 16bb36874..ae495f62b 100644 --- a/state/postgresql/postgresdbaccess.go +++ b/state/postgresql/postgresdbaccess.go @@ -14,6 +14,7 @@ limitations under the License. package postgresql import ( + "context" "database/sql" "encoding/base64" "encoding/json" @@ -98,12 +99,12 @@ func (p *postgresDBAccess) Init(metadata state.Metadata) error { } // Set makes an insert or update to the database. -func (p *postgresDBAccess) Set(req *state.SetRequest) error { - return state.SetWithOptions(p.setValue, req) +func (p *postgresDBAccess) Set(ctx context.Context, req *state.SetRequest) error { + return state.SetWithOptions(ctx, p.setValue, req) } // setValue is an internal implementation of set to enable passing the logic to state.SetWithRetries as a func. -func (p *postgresDBAccess) setValue(req *state.SetRequest) error { +func (p *postgresDBAccess) setValue(ctx context.Context, req *state.SetRequest) error { p.logger.Debug("Setting state value in PostgreSQL") err := state.CheckRequestOptions(req.Options) @@ -134,7 +135,7 @@ func (p *postgresDBAccess) setValue(req *state.SetRequest) error { // Sprintf is required for table name because sql.DB does not substitute parameters for table names. // Other parameters use sql.DB parameter substitution. if req.ETag == nil { - result, err = p.db.Exec(fmt.Sprintf( + result, err = p.db.ExecContext(ctx, fmt.Sprintf( `INSERT INTO %s (key, value, isbinary) VALUES ($1, $2, $3) ON CONFLICT (key) DO UPDATE SET value = $2, isbinary = $3, updatedate = NOW();`, tableName), req.Key, value, isBinary) @@ -148,7 +149,7 @@ func (p *postgresDBAccess) setValue(req *state.SetRequest) error { etag := uint32(etag64) // When an etag is provided do an update - no insert - result, err = p.db.Exec(fmt.Sprintf( + result, err = p.db.ExecContext(ctx, fmt.Sprintf( `UPDATE %s SET value = $1, isbinary = $2, updatedate = NOW() WHERE key = $3 AND xmin = $4;`, tableName), value, isBinary, req.Key, etag) @@ -174,7 +175,7 @@ func (p *postgresDBAccess) setValue(req *state.SetRequest) error { return nil } -func (p *postgresDBAccess) BulkSet(req []state.SetRequest) error { +func (p *postgresDBAccess) BulkSet(ctx context.Context, req []state.SetRequest) error { p.logger.Debug("Executing BulkSet request") tx, err := p.db.Begin() if err != nil { @@ -184,7 +185,7 @@ func (p *postgresDBAccess) BulkSet(req []state.SetRequest) error { if len(req) > 0 { for _, s := range req { sa := s // Fix for gosec G601: Implicit memory aliasing in for loop. - err = p.Set(&sa) + err = p.Set(ctx, &sa) if err != nil { tx.Rollback() @@ -199,7 +200,7 @@ func (p *postgresDBAccess) BulkSet(req []state.SetRequest) error { } // Get returns data from the database. If data does not exist for the key an empty state.GetResponse will be returned. -func (p *postgresDBAccess) Get(req *state.GetRequest) (*state.GetResponse, error) { +func (p *postgresDBAccess) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) { p.logger.Debug("Getting state value from PostgreSQL") if req.Key == "" { return nil, fmt.Errorf("missing key in get operation") @@ -208,7 +209,7 @@ func (p *postgresDBAccess) Get(req *state.GetRequest) (*state.GetResponse, error var value string var isBinary bool var etag int - err := p.db.QueryRow(fmt.Sprintf("SELECT value, isbinary, xmin as etag FROM %s WHERE key = $1", tableName), req.Key).Scan(&value, &isBinary, &etag) + err := p.db.QueryRowContext(ctx, fmt.Sprintf("SELECT value, isbinary, xmin as etag FROM %s WHERE key = $1", tableName), req.Key).Scan(&value, &isBinary, &etag) if err != nil { // If no rows exist, return an empty response, otherwise return the error. if err == sql.ErrNoRows { @@ -245,12 +246,12 @@ func (p *postgresDBAccess) Get(req *state.GetRequest) (*state.GetResponse, error } // Delete removes an item from the state store. -func (p *postgresDBAccess) Delete(req *state.DeleteRequest) error { - return state.DeleteWithOptions(p.deleteValue, req) +func (p *postgresDBAccess) Delete(ctx context.Context, req *state.DeleteRequest) error { + return state.DeleteWithOptions(ctx, p.deleteValue, req) } // deleteValue is an internal implementation of delete to enable passing the logic to state.DeleteWithRetries as a func. -func (p *postgresDBAccess) deleteValue(req *state.DeleteRequest) error { +func (p *postgresDBAccess) deleteValue(ctx context.Context, req *state.DeleteRequest) error { p.logger.Debug("Deleting state value from PostgreSQL") if req.Key == "" { return fmt.Errorf("missing key in delete operation") @@ -260,7 +261,7 @@ func (p *postgresDBAccess) deleteValue(req *state.DeleteRequest) error { var err error if req.ETag == nil { - result, err = p.db.Exec("DELETE FROM state WHERE key = $1", req.Key) + result, err = p.db.ExecContext(ctx, "DELETE FROM state WHERE key = $1", req.Key) } else { // Convert req.ETag to uint32 for postgres XID compatibility var etag64 uint64 @@ -270,7 +271,7 @@ func (p *postgresDBAccess) deleteValue(req *state.DeleteRequest) error { } etag := uint32(etag64) - result, err = p.db.Exec("DELETE FROM state WHERE key = $1 and xmin = $2", req.Key, etag) + result, err = p.db.ExecContext(ctx, "DELETE FROM state WHERE key = $1 and xmin = $2", req.Key, etag) } if err != nil { @@ -289,7 +290,7 @@ func (p *postgresDBAccess) deleteValue(req *state.DeleteRequest) error { return nil } -func (p *postgresDBAccess) BulkDelete(req []state.DeleteRequest) error { +func (p *postgresDBAccess) BulkDelete(ctx context.Context, req []state.DeleteRequest) error { p.logger.Debug("Executing BulkDelete request") tx, err := p.db.Begin() if err != nil { @@ -299,7 +300,7 @@ func (p *postgresDBAccess) BulkDelete(req []state.DeleteRequest) error { if len(req) > 0 { for _, d := range req { da := d // Fix for gosec G601: Implicit memory aliasing in for loop. - err = p.Delete(&da) + err = p.Delete(ctx, &da) if err != nil { tx.Rollback() @@ -313,7 +314,7 @@ func (p *postgresDBAccess) BulkDelete(req []state.DeleteRequest) error { return err } -func (p *postgresDBAccess) ExecuteMulti(request *state.TransactionalStateRequest) error { +func (p *postgresDBAccess) ExecuteMulti(ctx context.Context, request *state.TransactionalStateRequest) error { p.logger.Debug("Executing PostgreSQL transaction") tx, err := p.db.Begin() @@ -332,7 +333,7 @@ func (p *postgresDBAccess) ExecuteMulti(request *state.TransactionalStateRequest return err } - err = p.Set(&setReq) + err = p.Set(ctx, &setReq) if err != nil { tx.Rollback() return err @@ -347,7 +348,7 @@ func (p *postgresDBAccess) ExecuteMulti(request *state.TransactionalStateRequest return err } - err = p.Delete(&delReq) + err = p.Delete(ctx, &delReq) if err != nil { tx.Rollback() return err @@ -365,7 +366,7 @@ func (p *postgresDBAccess) ExecuteMulti(request *state.TransactionalStateRequest } // Query executes a query against store. -func (p *postgresDBAccess) Query(req *state.QueryRequest) (*state.QueryResponse, error) { +func (p *postgresDBAccess) Query(ctx context.Context, req *state.QueryRequest) (*state.QueryResponse, error) { p.logger.Debug("Getting query value from PostgreSQL") q := &Query{ query: "", @@ -375,7 +376,7 @@ func (p *postgresDBAccess) Query(req *state.QueryRequest) (*state.QueryResponse, if err := qbuilder.BuildQuery(&req.Query); err != nil { return &state.QueryResponse{}, err } - data, token, err := q.execute(p.logger, p.db) + data, token, err := q.execute(ctx, p.logger, p.db) if err != nil { return &state.QueryResponse{}, err } diff --git a/state/postgresql/postgresdbaccess_test.go b/state/postgresql/postgresdbaccess_test.go index 82057f501..e39c58512 100644 --- a/state/postgresql/postgresdbaccess_test.go +++ b/state/postgresql/postgresdbaccess_test.go @@ -15,6 +15,7 @@ limitations under the License. package postgresql import ( + "context" "database/sql" "testing" @@ -110,7 +111,7 @@ func TestMultiWithNoRequests(t *testing.T) { var operations []state.TransactionalStateOperation // Act - err := m.pgDba.ExecuteMulti(&state.TransactionalStateRequest{ + err := m.pgDba.ExecuteMulti(context.TODO(), &state.TransactionalStateRequest{ Operations: operations, }) @@ -134,7 +135,7 @@ func TestInvalidMultiInvalidAction(t *testing.T) { }) // Act - err := m.pgDba.ExecuteMulti(&state.TransactionalStateRequest{ + err := m.pgDba.ExecuteMulti(context.TODO(), &state.TransactionalStateRequest{ Operations: operations, }) @@ -159,7 +160,7 @@ func TestValidSetRequest(t *testing.T) { }) // Act - err := m.pgDba.ExecuteMulti(&state.TransactionalStateRequest{ + err := m.pgDba.ExecuteMulti(context.TODO(), &state.TransactionalStateRequest{ Operations: operations, }) @@ -183,7 +184,7 @@ func TestInvalidMultiSetRequest(t *testing.T) { }) // Act - err := m.pgDba.ExecuteMulti(&state.TransactionalStateRequest{ + err := m.pgDba.ExecuteMulti(context.TODO(), &state.TransactionalStateRequest{ Operations: operations, }) @@ -207,7 +208,7 @@ func TestInvalidMultiSetRequestNoKey(t *testing.T) { }) // Act - err := m.pgDba.ExecuteMulti(&state.TransactionalStateRequest{ + err := m.pgDba.ExecuteMulti(context.TODO(), &state.TransactionalStateRequest{ Operations: operations, }) @@ -232,7 +233,7 @@ func TestValidMultiDeleteRequest(t *testing.T) { }) // Act - err := m.pgDba.ExecuteMulti(&state.TransactionalStateRequest{ + err := m.pgDba.ExecuteMulti(context.TODO(), &state.TransactionalStateRequest{ Operations: operations, }) @@ -256,7 +257,7 @@ func TestInvalidMultiDeleteRequest(t *testing.T) { }) // Act - err := m.pgDba.ExecuteMulti(&state.TransactionalStateRequest{ + err := m.pgDba.ExecuteMulti(context.TODO(), &state.TransactionalStateRequest{ Operations: operations, }) @@ -280,7 +281,7 @@ func TestInvalidMultiDeleteRequestNoKey(t *testing.T) { }) // Act - err := m.pgDba.ExecuteMulti(&state.TransactionalStateRequest{ + err := m.pgDba.ExecuteMulti(context.TODO(), &state.TransactionalStateRequest{ Operations: operations, }) @@ -312,7 +313,7 @@ func TestMultiOperationOrder(t *testing.T) { ) // Act - err := m.pgDba.ExecuteMulti(&state.TransactionalStateRequest{ + err := m.pgDba.ExecuteMulti(context.TODO(), &state.TransactionalStateRequest{ Operations: operations, }) @@ -335,7 +336,7 @@ func TestInvalidBulkSetNoKey(t *testing.T) { }) // Act - err := m.pgDba.BulkSet(sets) + err := m.pgDba.BulkSet(context.TODO(), sets) // Assert assert.NotNil(t, err) @@ -357,7 +358,7 @@ func TestInvalidBulkSetEmptyValue(t *testing.T) { }) // Act - err := m.pgDba.BulkSet(sets) + err := m.pgDba.BulkSet(context.TODO(), sets) // Assert assert.NotNil(t, err) @@ -380,7 +381,7 @@ func TestValidBulkSet(t *testing.T) { }) // Act - err := m.pgDba.BulkSet(sets) + err := m.pgDba.BulkSet(context.TODO(), sets) // Assert assert.Nil(t, err) @@ -401,7 +402,7 @@ func TestInvalidBulkDeleteNoKey(t *testing.T) { }) // Act - err := m.pgDba.BulkDelete(deletes) + err := m.pgDba.BulkDelete(context.TODO(), deletes) // Assert assert.NotNil(t, err) @@ -423,7 +424,7 @@ func TestValidBulkDelete(t *testing.T) { }) // Act - err := m.pgDba.BulkDelete(deletes) + err := m.pgDba.BulkDelete(context.TODO(), deletes) // Assert assert.Nil(t, err) diff --git a/state/postgresql/postgresql.go b/state/postgresql/postgresql.go index 59782b298..f37aa564c 100644 --- a/state/postgresql/postgresql.go +++ b/state/postgresql/postgresql.go @@ -14,6 +14,8 @@ limitations under the License. package postgresql import ( + "context" + "github.com/dapr/components-contrib/state" "github.com/dapr/kit/logger" ) @@ -53,44 +55,44 @@ func (p *PostgreSQL) Features() []state.Feature { } // Delete removes an entity from the store. -func (p *PostgreSQL) Delete(req *state.DeleteRequest) error { - return p.dbaccess.Delete(req) +func (p *PostgreSQL) Delete(ctx context.Context, req *state.DeleteRequest) error { + return p.dbaccess.Delete(ctx, req) } // BulkDelete removes multiple entries from the store. -func (p *PostgreSQL) BulkDelete(req []state.DeleteRequest) error { - return p.dbaccess.BulkDelete(req) +func (p *PostgreSQL) BulkDelete(ctx context.Context, req []state.DeleteRequest) error { + return p.dbaccess.BulkDelete(ctx, req) } // Get returns an entity from store. -func (p *PostgreSQL) Get(req *state.GetRequest) (*state.GetResponse, error) { - return p.dbaccess.Get(req) +func (p *PostgreSQL) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) { + return p.dbaccess.Get(ctx, req) } // BulkGet performs a bulks get operations. -func (p *PostgreSQL) BulkGet(req []state.GetRequest) (bool, []state.BulkGetResponse, error) { +func (p *PostgreSQL) BulkGet(ctx context.Context, req []state.GetRequest) (bool, []state.BulkGetResponse, error) { // TODO: replace with ExecuteMulti for performance return false, nil, nil } // Set adds/updates an entity on store. -func (p *PostgreSQL) Set(req *state.SetRequest) error { - return p.dbaccess.Set(req) +func (p *PostgreSQL) Set(ctx context.Context, req *state.SetRequest) error { + return p.dbaccess.Set(ctx, req) } // BulkSet adds/updates multiple entities on store. -func (p *PostgreSQL) BulkSet(req []state.SetRequest) error { - return p.dbaccess.BulkSet(req) +func (p *PostgreSQL) BulkSet(ctx context.Context, req []state.SetRequest) error { + return p.dbaccess.BulkSet(ctx, req) } // Multi handles multiple transactions. Implements TransactionalStore. -func (p *PostgreSQL) Multi(request *state.TransactionalStateRequest) error { - return p.dbaccess.ExecuteMulti(request) +func (p *PostgreSQL) Multi(ctx context.Context, request *state.TransactionalStateRequest) error { + return p.dbaccess.ExecuteMulti(ctx, request) } // Query executes a query against store. -func (p *PostgreSQL) Query(req *state.QueryRequest) (*state.QueryResponse, error) { - return p.dbaccess.Query(req) +func (p *PostgreSQL) Query(ctx context.Context, req *state.QueryRequest) (*state.QueryResponse, error) { + return p.dbaccess.Query(ctx, req) } // Close implements io.Closer. diff --git a/state/postgresql/postgresql_integration_test.go b/state/postgresql/postgresql_integration_test.go index 52ba490a5..b2ac835fe 100644 --- a/state/postgresql/postgresql_integration_test.go +++ b/state/postgresql/postgresql_integration_test.go @@ -15,6 +15,7 @@ limitations under the License. package postgresql import ( + "context" "database/sql" "encoding/json" "fmt" @@ -192,7 +193,7 @@ func deleteItemThatDoesNotExist(t *testing.T, pgs *PostgreSQL) { deleteReq := &state.DeleteRequest{ Key: randomKey(), } - err := pgs.Delete(deleteReq) + err := pgs.Delete(context.TODO(), deleteReq) assert.Nil(t, err) } @@ -211,7 +212,7 @@ func multiWithSetOnly(t *testing.T, pgs *PostgreSQL) { }) } - err := pgs.Multi(&state.TransactionalStateRequest{ + err := pgs.Multi(context.TODO(), &state.TransactionalStateRequest{ Operations: operations, }) assert.Nil(t, err) @@ -241,7 +242,7 @@ func multiWithDeleteOnly(t *testing.T, pgs *PostgreSQL) { }) } - err := pgs.Multi(&state.TransactionalStateRequest{ + err := pgs.Multi(context.TODO(), &state.TransactionalStateRequest{ Operations: operations, }) assert.Nil(t, err) @@ -284,7 +285,7 @@ func multiWithDeleteAndSet(t *testing.T, pgs *PostgreSQL) { }) } - err := pgs.Multi(&state.TransactionalStateRequest{ + err := pgs.Multi(context.TODO(), &state.TransactionalStateRequest{ Operations: operations, }) assert.Nil(t, err) @@ -311,7 +312,7 @@ func deleteWithInvalidEtagFails(t *testing.T, pgs *PostgreSQL) { Key: key, ETag: &etag, } - err := pgs.Delete(deleteReq) + err := pgs.Delete(context.TODO(), deleteReq) assert.NotNil(t, err) } @@ -319,7 +320,7 @@ func deleteWithNoKeyFails(t *testing.T, pgs *PostgreSQL) { deleteReq := &state.DeleteRequest{ Key: "", } - err := pgs.Delete(deleteReq) + err := pgs.Delete(context.TODO(), deleteReq) assert.NotNil(t, err) } @@ -334,7 +335,7 @@ func newItemWithEtagFails(t *testing.T, pgs *PostgreSQL) { Value: value, } - err := pgs.Set(setReq) + err := pgs.Set(context.TODO(), setReq) assert.NotNil(t, err) } @@ -360,7 +361,7 @@ func updateWithOldEtagFails(t *testing.T, pgs *PostgreSQL) { ETag: originalEtag, Value: newValue, } - err := pgs.Set(setReq) + err := pgs.Set(context.TODO(), setReq) assert.NotNil(t, err) } @@ -402,7 +403,7 @@ func getItemWithNoKey(t *testing.T, pgs *PostgreSQL) { Key: "", } - response, getErr := pgs.Get(getReq) + response, getErr := pgs.Get(context.TODO(), getReq) assert.NotNil(t, getErr) assert.Nil(t, response) } @@ -433,7 +434,7 @@ func setItemWithNoKey(t *testing.T, pgs *PostgreSQL) { Key: "", } - err := pgs.Set(setReq) + err := pgs.Set(context.TODO(), setReq) assert.NotNil(t, err) } @@ -450,7 +451,7 @@ func testBulkSetAndBulkDelete(t *testing.T, pgs *PostgreSQL) { }, } - err := pgs.BulkSet(setReq) + err := pgs.BulkSet(context.TODO(), setReq) assert.Nil(t, err) assert.True(t, storeItemExists(t, setReq[0].Key)) assert.True(t, storeItemExists(t, setReq[1].Key)) @@ -464,7 +465,7 @@ func testBulkSetAndBulkDelete(t *testing.T, pgs *PostgreSQL) { }, } - err = pgs.BulkDelete(deleteReq) + err = pgs.BulkDelete(context.TODO(), deleteReq) assert.Nil(t, err) assert.False(t, storeItemExists(t, setReq[0].Key)) assert.False(t, storeItemExists(t, setReq[1].Key)) @@ -521,7 +522,7 @@ func setItem(t *testing.T, pgs *PostgreSQL, key string, value interface{}, etag Value: value, } - err := pgs.Set(setReq) + err := pgs.Set(context.TODO(), setReq) assert.Nil(t, err) itemExists := storeItemExists(t, key) assert.True(t, itemExists) @@ -533,7 +534,7 @@ func getItem(t *testing.T, pgs *PostgreSQL, key string) (*state.GetResponse, *fa Options: state.GetStateOption{}, } - response, getErr := pgs.Get(getReq) + response, getErr := pgs.Get(context.TODO(), getReq) assert.Nil(t, getErr) assert.NotNil(t, response) outputObject := &fakeItem{} @@ -549,7 +550,7 @@ func deleteItem(t *testing.T, pgs *PostgreSQL, key string, etag *string) { Options: state.DeleteStateOption{}, } - deleteErr := pgs.Delete(deleteReq) + deleteErr := pgs.Delete(context.TODO(), deleteReq) assert.Nil(t, deleteErr) assert.False(t, storeItemExists(t, key)) } diff --git a/state/postgresql/postgresql_query.go b/state/postgresql/postgresql_query.go index b18866119..09530c0a0 100644 --- a/state/postgresql/postgresql_query.go +++ b/state/postgresql/postgresql_query.go @@ -15,6 +15,7 @@ limitations under the License. package postgresql import ( + "context" "database/sql" "fmt" "strconv" @@ -139,8 +140,8 @@ func (q *Query) Finalize(filters string, qq *query.Query) error { return nil } -func (q *Query) execute(logger logger.Logger, db *sql.DB) ([]state.QueryItem, string, error) { - rows, err := db.Query(q.query, q.params...) +func (q *Query) execute(ctx context.Context, logger logger.Logger, db *sql.DB) ([]state.QueryItem, string, error) { + rows, err := db.QueryContext(ctx, q.query, q.params...) if err != nil { return nil, "", err } diff --git a/state/postgresql/postgresql_test.go b/state/postgresql/postgresql_test.go index 99ab088ef..1697909b1 100644 --- a/state/postgresql/postgresql_test.go +++ b/state/postgresql/postgresql_test.go @@ -15,6 +15,7 @@ limitations under the License. package postgresql import ( + "context" "testing" "github.com/stretchr/testify/assert" @@ -43,37 +44,37 @@ func (m *fakeDBaccess) Init(metadata state.Metadata) error { return nil } -func (m *fakeDBaccess) Set(req *state.SetRequest) error { +func (m *fakeDBaccess) Set(ctx context.Context, req *state.SetRequest) error { m.setExecuted = true return nil } -func (m *fakeDBaccess) Get(req *state.GetRequest) (*state.GetResponse, error) { +func (m *fakeDBaccess) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) { m.getExecuted = true return nil, nil } -func (m *fakeDBaccess) Delete(req *state.DeleteRequest) error { +func (m *fakeDBaccess) Delete(ctx context.Context, req *state.DeleteRequest) error { m.deleteExecuted = true return nil } -func (m *fakeDBaccess) BulkSet(req []state.SetRequest) error { +func (m *fakeDBaccess) BulkSet(ctx context.Context, req []state.SetRequest) error { return nil } -func (m *fakeDBaccess) BulkDelete(req []state.DeleteRequest) error { +func (m *fakeDBaccess) BulkDelete(ctx context.Context, req []state.DeleteRequest) error { return nil } -func (m *fakeDBaccess) ExecuteMulti(req *state.TransactionalStateRequest) error { +func (m *fakeDBaccess) ExecuteMulti(ctx context.Context, req *state.TransactionalStateRequest) error { return nil } -func (m *fakeDBaccess) Query(req *state.QueryRequest) (*state.QueryResponse, error) { +func (m *fakeDBaccess) Query(ctx context.Context, req *state.QueryRequest) (*state.QueryResponse, error) { return nil, nil } diff --git a/state/redis/redis.go b/state/redis/redis.go index 3ff7c55fe..2f9edfbfc 100644 --- a/state/redis/redis.go +++ b/state/redis/redis.go @@ -195,7 +195,7 @@ func (r *StateStore) parseConnectedSlaves(res string) int { return 0 } -func (r *StateStore) deleteValue(req *state.DeleteRequest) error { +func (r *StateStore) deleteValue(ctx context.Context, req *state.DeleteRequest) error { if req.ETag == nil { etag := "0" req.ETag = &etag @@ -207,7 +207,7 @@ func (r *StateStore) deleteValue(req *state.DeleteRequest) error { } else { delQuery = delDefaultQuery } - _, err := r.client.Do(r.ctx, "EVAL", delQuery, 1, req.Key, *req.ETag).Result() + _, err := r.client.Do(ctx, "EVAL", delQuery, 1, req.Key, *req.ETag).Result() if err != nil { return state.NewETagError(state.ETagMismatch, err) } @@ -216,17 +216,17 @@ func (r *StateStore) deleteValue(req *state.DeleteRequest) error { } // Delete performs a delete operation. -func (r *StateStore) Delete(req *state.DeleteRequest) error { +func (r *StateStore) Delete(ctx context.Context, req *state.DeleteRequest) error { err := state.CheckRequestOptions(req.Options) if err != nil { return err } - return state.DeleteWithOptions(r.deleteValue, req) + return state.DeleteWithOptions(ctx, r.deleteValue, req) } -func (r *StateStore) directGet(req *state.GetRequest) (*state.GetResponse, error) { - res, err := r.client.Do(r.ctx, "GET", req.Key).Result() +func (r *StateStore) directGet(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) { + res, err := r.client.Do(ctx, "GET", req.Key).Result() if err != nil { return nil, err } @@ -242,10 +242,10 @@ func (r *StateStore) directGet(req *state.GetRequest) (*state.GetResponse, error }, nil } -func (r *StateStore) getDefault(req *state.GetRequest) (*state.GetResponse, error) { - res, err := r.client.Do(r.ctx, "HGETALL", req.Key).Result() // Prefer values with ETags +func (r *StateStore) getDefault(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) { + res, err := r.client.Do(ctx, "HGETALL", req.Key).Result() // Prefer values with ETags if err != nil { - return r.directGet(req) // Falls back to original get for backward compats. + return r.directGet(ctx, req) // Falls back to original get for backward compats. } if res == nil { return &state.GetResponse{}, nil @@ -304,12 +304,12 @@ func (r *StateStore) getJSON(req *state.GetRequest) (*state.GetResponse, error) } // Get retrieves state from redis with a key. -func (r *StateStore) Get(req *state.GetRequest) (*state.GetResponse, error) { +func (r *StateStore) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) { if contentType, ok := req.Metadata[daprmetadata.ContentType]; ok && contentType == contenttype.JSONContentType { return r.getJSON(req) } - return r.getDefault(req) + return r.getDefault(ctx, req) } type jsonEntry struct { @@ -317,7 +317,7 @@ type jsonEntry struct { Version *int `json:"version,omitempty"` } -func (r *StateStore) setValue(req *state.SetRequest) error { +func (r *StateStore) setValue(ctx context.Context, req *state.SetRequest) error { err := state.CheckRequestOptions(req.Options) if err != nil { return err @@ -350,7 +350,7 @@ func (r *StateStore) setValue(req *state.SetRequest) error { bt, _ = utils.Marshal(req.Value, r.json.Marshal) } - err = r.client.Do(r.ctx, "EVAL", setQuery, 1, req.Key, ver, bt, firstWrite).Err() + err = r.client.Do(ctx, "EVAL", setQuery, 1, req.Key, ver, bt, firstWrite).Err() if err != nil { if req.ETag != nil { return state.NewETagError(state.ETagMismatch, err) @@ -360,21 +360,21 @@ func (r *StateStore) setValue(req *state.SetRequest) error { } if ttl != nil && *ttl > 0 { - _, err = r.client.Do(r.ctx, "EXPIRE", req.Key, *ttl).Result() + _, err = r.client.Do(ctx, "EXPIRE", req.Key, *ttl).Result() if err != nil { return fmt.Errorf("failed to set key %s ttl: %s", req.Key, err) } } if ttl != nil && *ttl <= 0 { - _, err = r.client.Do(r.ctx, "PERSIST", req.Key).Result() + _, err = r.client.Do(ctx, "PERSIST", req.Key).Result() if err != nil { return fmt.Errorf("failed to persist key %s: %s", req.Key, err) } } if req.Options.Consistency == state.Strong && r.replicas > 0 { - _, err = r.client.Do(r.ctx, "WAIT", r.replicas, 1000).Result() + _, err = r.client.Do(ctx, "WAIT", r.replicas, 1000).Result() if err != nil { return fmt.Errorf("redis waiting for %v replicas to acknowledge write, err: %s", r.replicas, err.Error()) } @@ -384,12 +384,12 @@ func (r *StateStore) setValue(req *state.SetRequest) error { } // Set saves state into redis. -func (r *StateStore) Set(req *state.SetRequest) error { - return state.SetWithOptions(r.setValue, req) +func (r *StateStore) Set(ctx context.Context, req *state.SetRequest) error { + return state.SetWithOptions(ctx, r.setValue, req) } // Multi performs a transactional operation. succeeds only if all operations succeed, and fails if one or more operations fail. -func (r *StateStore) Multi(request *state.TransactionalStateRequest) error { +func (r *StateStore) Multi(ctx context.Context, request *state.TransactionalStateRequest) error { var setQuery, delQuery string var isJSON bool if contentType, ok := request.Metadata[daprmetadata.ContentType]; ok && contentType == contenttype.JSONContentType { @@ -423,12 +423,12 @@ func (r *StateStore) Multi(request *state.TransactionalStateRequest) error { } else { bt, _ = utils.Marshal(req.Value, r.json.Marshal) } - pipe.Do(r.ctx, "EVAL", setQuery, 1, req.Key, ver, bt) + pipe.Do(ctx, "EVAL", setQuery, 1, req.Key, ver, bt) if ttl != nil && *ttl > 0 { - pipe.Do(r.ctx, "EXPIRE", req.Key, *ttl) + pipe.Do(ctx, "EXPIRE", req.Key, *ttl) } if ttl != nil && *ttl <= 0 { - pipe.Do(r.ctx, "PERSIST", req.Key) + pipe.Do(ctx, "PERSIST", req.Key) } } else if o.Operation == state.Delete { req := o.Request.(state.DeleteRequest) @@ -436,11 +436,11 @@ func (r *StateStore) Multi(request *state.TransactionalStateRequest) error { etag := "0" req.ETag = &etag } - pipe.Do(r.ctx, "EVAL", delQuery, 1, req.Key, *req.ETag) + pipe.Do(ctx, "EVAL", delQuery, 1, req.Key, *req.ETag) } } - _, err := pipe.Exec(r.ctx) + _, err := pipe.Exec(ctx) return err } @@ -514,7 +514,7 @@ func (r *StateStore) parseTTL(req *state.SetRequest) (*int, error) { } // Query executes a query against store. -func (r *StateStore) Query(req *state.QueryRequest) (*state.QueryResponse, error) { +func (r *StateStore) Query(ctx context.Context, req *state.QueryRequest) (*state.QueryResponse, error) { indexName, ok := daprmetadata.TryGetQueryIndexName(req.Metadata) if !ok { return nil, fmt.Errorf("query index not found") @@ -529,7 +529,7 @@ func (r *StateStore) Query(req *state.QueryRequest) (*state.QueryResponse, error if err := qbuilder.BuildQuery(&req.Query); err != nil { return &state.QueryResponse{}, err } - data, token, err := q.execute(r.ctx, r.client) + data, token, err := q.execute(ctx, r.client) if err != nil { return &state.QueryResponse{}, err } diff --git a/state/redis/redis_test.go b/state/redis/redis_test.go index fc99bd7a4..3279e2ab2 100644 --- a/state/redis/redis_test.go +++ b/state/redis/redis_test.go @@ -206,7 +206,7 @@ func TestTransactionalUpsert(t *testing.T) { } ss.ctx, ss.cancel = context.WithCancel(context.Background()) - err := ss.Multi(&state.TransactionalStateRequest{ + err := ss.Multi(context.TODO(), &state.TransactionalStateRequest{ Operations: []state.TransactionalStateOperation{ { Operation: state.Upsert, @@ -273,13 +273,13 @@ func TestTransactionalDelete(t *testing.T) { ss.ctx, ss.cancel = context.WithCancel(context.Background()) // Insert a record first. - ss.Set(&state.SetRequest{ + ss.Set(context.TODO(), &state.SetRequest{ Key: "weapon", Value: "deathstar", }) etag := "1" - err := ss.Multi(&state.TransactionalStateRequest{ + err := ss.Multi(context.TODO(), &state.TransactionalStateRequest{ Operations: []state.TransactionalStateOperation{{ Operation: state.Delete, Request: state.DeleteRequest{ @@ -331,7 +331,7 @@ func TestRequestsWithGlobalTTL(t *testing.T) { ss.ctx, ss.cancel = context.WithCancel(context.Background()) t.Run("TTL: Only global specified", func(t *testing.T) { - ss.Set(&state.SetRequest{ + ss.Set(context.TODO(), &state.SetRequest{ Key: "weapon100", Value: "deathstar100", }) @@ -342,7 +342,7 @@ func TestRequestsWithGlobalTTL(t *testing.T) { t.Run("TTL: Global and Request specified", func(t *testing.T) { requestTTL := 200 - ss.Set(&state.SetRequest{ + ss.Set(context.TODO(), &state.SetRequest{ Key: "weapon100", Value: "deathstar100", Metadata: map[string]string{ @@ -355,7 +355,7 @@ func TestRequestsWithGlobalTTL(t *testing.T) { }) t.Run("TTL: Global and Request specified", func(t *testing.T) { - err := ss.Multi(&state.TransactionalStateRequest{ + err := ss.Multi(context.TODO(), &state.TransactionalStateRequest{ Operations: []state.TransactionalStateOperation{ { Operation: state.Upsert, @@ -424,7 +424,7 @@ func TestSetRequestWithTTL(t *testing.T) { t.Run("TTL specified", func(t *testing.T) { ttlInSeconds := 100 - ss.Set(&state.SetRequest{ + ss.Set(context.TODO(), &state.SetRequest{ Key: "weapon100", Value: "deathstar100", Metadata: map[string]string{ @@ -438,7 +438,7 @@ func TestSetRequestWithTTL(t *testing.T) { }) t.Run("TTL not specified", func(t *testing.T) { - ss.Set(&state.SetRequest{ + ss.Set(context.TODO(), &state.SetRequest{ Key: "weapon200", Value: "deathstar200", }) @@ -449,7 +449,7 @@ func TestSetRequestWithTTL(t *testing.T) { }) t.Run("TTL Changed for Existing Key", func(t *testing.T) { - ss.Set(&state.SetRequest{ + ss.Set(context.TODO(), &state.SetRequest{ Key: "weapon300", Value: "deathstar300", }) @@ -458,7 +458,7 @@ func TestSetRequestWithTTL(t *testing.T) { // make the key no longer persistent ttlInSeconds := 123 - ss.Set(&state.SetRequest{ + ss.Set(context.TODO(), &state.SetRequest{ Key: "weapon300", Value: "deathstar300", Metadata: map[string]string{ @@ -469,7 +469,7 @@ func TestSetRequestWithTTL(t *testing.T) { assert.Equal(t, time.Duration(ttlInSeconds)*time.Second, ttl) // make the key persistent again - ss.Set(&state.SetRequest{ + ss.Set(context.TODO(), &state.SetRequest{ Key: "weapon300", Value: "deathstar301", Metadata: map[string]string{ @@ -493,12 +493,12 @@ func TestTransactionalDeleteNoEtag(t *testing.T) { ss.ctx, ss.cancel = context.WithCancel(context.Background()) // Insert a record first. - ss.Set(&state.SetRequest{ + ss.Set(context.TODO(), &state.SetRequest{ Key: "weapon100", Value: "deathstar100", }) - err := ss.Multi(&state.TransactionalStateRequest{ + err := ss.Multi(context.TODO(), &state.TransactionalStateRequest{ Operations: []state.TransactionalStateOperation{{ Operation: state.Delete, Request: state.DeleteRequest{ diff --git a/state/request_options.go b/state/request_options.go index 23ed85cb3..3e1d9b5c8 100644 --- a/state/request_options.go +++ b/state/request_options.go @@ -14,6 +14,7 @@ limitations under the License. package state import ( + "context" "fmt" ) @@ -68,11 +69,11 @@ func validateConsistencyOption(c string) error { } // SetWithOptions handles SetRequest with request options. -func SetWithOptions(method func(req *SetRequest) error, req *SetRequest) error { - return method(req) +func SetWithOptions(ctx context.Context, method func(ctx context.Context, req *SetRequest) error, req *SetRequest) error { + return method(ctx, req) } // DeleteWithOptions handles DeleteRequest with options. -func DeleteWithOptions(method func(req *DeleteRequest) error, req *DeleteRequest) error { - return method(req) +func DeleteWithOptions(ctx context.Context, method func(ctx context.Context, req *DeleteRequest) error, req *DeleteRequest) error { + return method(ctx, req) } diff --git a/state/request_options_test.go b/state/request_options_test.go index 2bf8e72f3..914dcdce7 100644 --- a/state/request_options_test.go +++ b/state/request_options_test.go @@ -14,6 +14,7 @@ limitations under the License. package state import ( + "context" "testing" "github.com/stretchr/testify/assert" @@ -23,7 +24,7 @@ import ( func TestSetRequestWithOptions(t *testing.T) { t.Run("set with default options", func(t *testing.T) { counter := 0 - SetWithOptions(func(req *SetRequest) error { + SetWithOptions(context.TODO(), func(ctx context.Context, req *SetRequest) error { counter++ return nil @@ -33,7 +34,7 @@ func TestSetRequestWithOptions(t *testing.T) { t.Run("set with no explicit options", func(t *testing.T) { counter := 0 - SetWithOptions(func(req *SetRequest) error { + SetWithOptions(context.TODO(), func(ctx context.Context, req *SetRequest) error { counter++ return nil diff --git a/state/rethinkdb/rethinkdb.go b/state/rethinkdb/rethinkdb.go index 9a43389b4..c36db6341 100644 --- a/state/rethinkdb/rethinkdb.go +++ b/state/rethinkdb/rethinkdb.go @@ -14,6 +14,7 @@ limitations under the License. package rethinkdb import ( + "context" "encoding/json" "io" "strconv" @@ -147,7 +148,7 @@ func tableExists(arr []string, table string) bool { } // Get retrieves a RethinkDB KV item. -func (s *RethinkDB) Get(req *state.GetRequest) (*state.GetResponse, error) { +func (s *RethinkDB) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) { if req == nil || req.Key == "" { return nil, errors.New("invalid state request, missing key") } @@ -187,22 +188,22 @@ func (s *RethinkDB) Get(req *state.GetRequest) (*state.GetResponse, error) { } // BulkGet performs a bulks get operations. -func (s *RethinkDB) BulkGet(req []state.GetRequest) (bool, []state.BulkGetResponse, error) { +func (s *RethinkDB) BulkGet(ctx context.Context, req []state.GetRequest) (bool, []state.BulkGetResponse, error) { // TODO: replace with bulk get for performance return false, nil, nil } // Set saves a state KV item. -func (s *RethinkDB) Set(req *state.SetRequest) error { +func (s *RethinkDB) Set(ctx context.Context, req *state.SetRequest) error { if req == nil || req.Key == "" || req.Value == nil { return errors.New("invalid state request, key and value required") } - return s.BulkSet([]state.SetRequest{*req}) + return s.BulkSet(ctx, []state.SetRequest{*req}) } // BulkSet performs a bulk save operation. -func (s *RethinkDB) BulkSet(req []state.SetRequest) error { +func (s *RethinkDB) BulkSet(ctx context.Context, req []state.SetRequest) error { docs := make([]*stateRecord, len(req)) for i, v := range req { var etag string @@ -257,16 +258,16 @@ func (s *RethinkDB) archive(changes []r.ChangeResponse) error { } // Delete performes a RethinkDB KV delete operation. -func (s *RethinkDB) Delete(req *state.DeleteRequest) error { +func (s *RethinkDB) Delete(ctx context.Context, req *state.DeleteRequest) error { if req == nil || req.Key == "" { return errors.New("invalid request, missing key") } - return s.BulkDelete([]state.DeleteRequest{*req}) + return s.BulkDelete(ctx, []state.DeleteRequest{*req}) } // BulkDelete performs a bulk delete operation. -func (s *RethinkDB) BulkDelete(req []state.DeleteRequest) error { +func (s *RethinkDB) BulkDelete(ctx context.Context, req []state.DeleteRequest) error { list := make([]string, 0) for _, d := range req { list = append(list, d.Key) @@ -282,7 +283,7 @@ func (s *RethinkDB) BulkDelete(req []state.DeleteRequest) error { } // Multi performs multiple operations. -func (s *RethinkDB) Multi(req *state.TransactionalStateRequest) error { +func (s *RethinkDB) Multi(ctx context.Context, req *state.TransactionalStateRequest) error { upserts := make([]state.SetRequest, 0) deletes := make([]state.DeleteRequest, 0) @@ -306,11 +307,11 @@ func (s *RethinkDB) Multi(req *state.TransactionalStateRequest) error { } // best effort, no transacts supported - if err := s.BulkSet(upserts); err != nil { + if err := s.BulkSet(ctx, upserts); err != nil { return errors.Wrap(err, "error saving records to the database") } - if err := s.BulkDelete(deletes); err != nil { + if err := s.BulkDelete(ctx, deletes); err != nil { return errors.Wrap(err, "error deleting records to the database") } diff --git a/state/rethinkdb/rethinkdb_test.go b/state/rethinkdb/rethinkdb_test.go index c29439069..d7e804ae9 100644 --- a/state/rethinkdb/rethinkdb_test.go +++ b/state/rethinkdb/rethinkdb_test.go @@ -14,6 +14,7 @@ limitations under the License. package rethinkdb import ( + "context" "encoding/json" "fmt" "os" @@ -86,12 +87,12 @@ func TestRethinkDBStateStore(t *testing.T) { d := &testObj{F1: "test", F2: 1, F3: time.Now().UTC()} k := fmt.Sprintf("ids-%d", time.Now().UnixNano()) - if err := db.Set(&state.SetRequest{Key: k, Value: d}); err != nil { + if err := db.Set(context.TODO(), &state.SetRequest{Key: k, Value: d}); err != nil { t.Fatalf("error setting data to db: %v", err) } // get set data and compare - resp, err := db.Get(&state.GetRequest{Key: k}) + resp, err := db.Get(context.TODO(), &state.GetRequest{Key: k}) assert.Nil(t, err) d2 := testGetTestObj(t, resp) assert.NotNil(t, d2) @@ -103,12 +104,12 @@ func TestRethinkDBStateStore(t *testing.T) { d2.F2 = 2 d2.F3 = time.Now().UTC() tag := fmt.Sprintf("hash-%d", time.Now().UnixNano()) - if err = db.Set(&state.SetRequest{Key: k, Value: d2, ETag: &tag}); err != nil { + if err = db.Set(context.TODO(), &state.SetRequest{Key: k, Value: d2, ETag: &tag}); err != nil { t.Fatalf("error setting data to db: %v", err) } // get updated data and compare - resp2, err := db.Get(&state.GetRequest{Key: k}) + resp2, err := db.Get(context.TODO(), &state.GetRequest{Key: k}) assert.Nil(t, err) d3 := testGetTestObj(t, resp2) assert.NotNil(t, d3) @@ -117,7 +118,7 @@ func TestRethinkDBStateStore(t *testing.T) { assert.Equal(t, d2.F3.Format(time.RFC3339), d3.F3.Format(time.RFC3339)) // delete data - if err := db.Delete(&state.DeleteRequest{Key: k}); err != nil { + if err := db.Delete(context.TODO(), &state.DeleteRequest{Key: k}); err != nil { t.Fatalf("error on data deletion: %v", err) } }) @@ -127,19 +128,19 @@ func TestRethinkDBStateStore(t *testing.T) { d := []byte("test") k := fmt.Sprintf("idb-%d", time.Now().UnixNano()) - if err := db.Set(&state.SetRequest{Key: k, Value: d}); err != nil { + if err := db.Set(context.TODO(), &state.SetRequest{Key: k, Value: d}); err != nil { t.Fatalf("error setting data to db: %v", err) } // get set data and compare - resp, err := db.Get(&state.GetRequest{Key: k}) + resp, err := db.Get(context.TODO(), &state.GetRequest{Key: k}) assert.Nil(t, err) assert.NotNil(t, resp) assert.NotNil(t, resp.Data) assert.Equal(t, string(d), string(resp.Data)) // delete data - if err := db.Delete(&state.DeleteRequest{Key: k}); err != nil { + if err := db.Delete(context.TODO(), &state.DeleteRequest{Key: k}); err != nil { t.Fatalf("error on data deletion: %v", err) } }) @@ -177,26 +178,26 @@ func testBulk(t *testing.T, db *RethinkDB, i int) { } // bulk set it - if err := db.BulkSet(setList); err != nil { + if err := db.BulkSet(context.TODO(), setList); err != nil { t.Fatalf("error setting data to db: %v -- run %d", err, i) } // check for the data for _, v := range deleteList { - resp, err := db.Get(&state.GetRequest{Key: v.Key}) + resp, err := db.Get(context.TODO(), &state.GetRequest{Key: v.Key}) assert.Nilf(t, err, " -- run %d", i) assert.NotNil(t, resp) assert.NotNil(t, resp.Data) } // delete data - if err := db.BulkDelete(deleteList); err != nil { + if err := db.BulkDelete(context.TODO(), deleteList); err != nil { t.Fatalf("error on data deletion: %v -- run %d", err, i) } // check for the data NOT being there for _, v := range deleteList { - resp, err := db.Get(&state.GetRequest{Key: v.Key}) + resp, err := db.Get(context.TODO(), &state.GetRequest{Key: v.Key}) assert.Nilf(t, err, " -- run %d", i) assert.NotNil(t, resp) assert.Nil(t, resp.Data) @@ -224,7 +225,7 @@ func TestRethinkDBStateStoreMulti(t *testing.T) { for i := 0; i < numOfRecords; i++ { list[i] = state.SetRequest{Key: fmt.Sprintf(recordIDFormat, i), Value: d} } - if err := db.BulkSet(list); err != nil { + if err := db.BulkSet(context.TODO(), list); err != nil { t.Fatalf("error setting multi to db: %v", err) } @@ -258,19 +259,19 @@ func TestRethinkDBStateStoreMulti(t *testing.T) { } // execute multi - if err := db.Multi(req); err != nil { + if err := db.Multi(context.TODO(), req); err != nil { t.Fatalf("error setting multi to db: %v", err) } // the one not deleted should be still there - m1, err := db.Get(&state.GetRequest{Key: fmt.Sprintf(recordIDFormat, 1)}) + m1, err := db.Get(context.TODO(), &state.GetRequest{Key: fmt.Sprintf(recordIDFormat, 1)}) assert.Nil(t, err) assert.NotNil(t, m1) assert.NotNil(t, m1.Data) assert.Equal(t, string(d2), string(m1.Data)) // the one deleted should not - m2, err := db.Get(&state.GetRequest{Key: fmt.Sprintf(recordIDFormat, 3)}) + m2, err := db.Get(context.TODO(), &state.GetRequest{Key: fmt.Sprintf(recordIDFormat, 3)}) assert.Nil(t, err) assert.NotNil(t, m2) assert.Nil(t, m2.Data) diff --git a/state/sqlserver/sqlserver.go b/state/sqlserver/sqlserver.go index 7caeff835..f8d117760 100644 --- a/state/sqlserver/sqlserver.go +++ b/state/sqlserver/sqlserver.go @@ -14,6 +14,7 @@ limitations under the License. package sqlserver import ( + "context" "database/sql" "encoding/hex" "encoding/json" @@ -338,7 +339,7 @@ func (s *SQLServer) Features() []state.Feature { } // Multi performs multiple updates on a Sql server store. -func (s *SQLServer) Multi(request *state.TransactionalStateRequest) error { +func (s *SQLServer) Multi(ctx context.Context, request *state.TransactionalStateRequest) error { tx, err := s.db.Begin() if err != nil { return err @@ -353,7 +354,7 @@ func (s *SQLServer) Multi(request *state.TransactionalStateRequest) error { return err } - err = s.executeSet(tx, &setReq) + err = s.executeSet(ctx, tx, &setReq) if err != nil { tx.Rollback() return err @@ -366,7 +367,7 @@ func (s *SQLServer) Multi(request *state.TransactionalStateRequest) error { return err } - err = s.executeDelete(tx, &delReq) + err = s.executeDelete(ctx, tx, &delReq) if err != nil { tx.Rollback() return err @@ -410,11 +411,11 @@ func (s *SQLServer) getDeletes(req state.TransactionalStateOperation) (state.Del } // Delete removes an entity from the store. -func (s *SQLServer) Delete(req *state.DeleteRequest) error { - return s.executeDelete(s.db, req) +func (s *SQLServer) Delete(ctx context.Context, req *state.DeleteRequest) error { + return s.executeDelete(ctx, s.db, req) } -func (s *SQLServer) executeDelete(db dbExecutor, req *state.DeleteRequest) error { +func (s *SQLServer) executeDelete(ctx context.Context, db dbExecutor, req *state.DeleteRequest) error { var err error var res sql.Result if req.ETag != nil { @@ -424,9 +425,9 @@ func (s *SQLServer) executeDelete(db dbExecutor, req *state.DeleteRequest) error return state.NewETagError(state.ETagInvalid, err) } - res, err = db.Exec(s.deleteWithETagCommand, sql.Named(keyColumnName, req.Key), sql.Named(rowVersionColumnName, b)) + res, err = db.ExecContext(ctx, s.deleteWithETagCommand, sql.Named(keyColumnName, req.Key), sql.Named(rowVersionColumnName, b)) } else { - res, err = db.Exec(s.deleteWithoutETagCommand, sql.Named(keyColumnName, req.Key)) + res, err = db.ExecContext(ctx, s.deleteWithoutETagCommand, sql.Named(keyColumnName, req.Key)) } // err represents errors thrown by the stored procedure or the database itself @@ -456,13 +457,13 @@ type TvpDeleteTableStringKey struct { } // BulkDelete removes multiple entries from the store. -func (s *SQLServer) BulkDelete(req []state.DeleteRequest) error { +func (s *SQLServer) BulkDelete(ctx context.Context, req []state.DeleteRequest) error { tx, err := s.db.Begin() if err != nil { return err } - err = s.executeBulkDelete(tx, req) + err = s.executeBulkDelete(ctx, tx, req) if err != nil { tx.Rollback() @@ -474,7 +475,7 @@ func (s *SQLServer) BulkDelete(req []state.DeleteRequest) error { return nil } -func (s *SQLServer) executeBulkDelete(db dbExecutor, req []state.DeleteRequest) error { +func (s *SQLServer) executeBulkDelete(ctx context.Context, db dbExecutor, req []state.DeleteRequest) error { values := make([]TvpDeleteTableStringKey, len(req)) for i, d := range req { var etag []byte @@ -493,7 +494,7 @@ func (s *SQLServer) executeBulkDelete(db dbExecutor, req []state.DeleteRequest) Value: values, } - res, err := db.Exec(s.bulkDeleteCommand, sql.Named("itemsToDelete", itemsToDelete)) + res, err := db.ExecContext(ctx, s.bulkDeleteCommand, sql.Named("itemsToDelete", itemsToDelete)) if err != nil { return err } @@ -513,7 +514,7 @@ func (s *SQLServer) executeBulkDelete(db dbExecutor, req []state.DeleteRequest) } // Get returns an entity from store. -func (s *SQLServer) Get(req *state.GetRequest) (*state.GetResponse, error) { +func (s *SQLServer) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) { rows, err := s.db.Query(s.getCommand, sql.Named(keyColumnName, req.Key)) if err != nil { return nil, err @@ -545,21 +546,21 @@ func (s *SQLServer) Get(req *state.GetRequest) (*state.GetResponse, error) { } // BulkGet performs a bulks get operations. -func (s *SQLServer) BulkGet(req []state.GetRequest) (bool, []state.BulkGetResponse, error) { +func (s *SQLServer) BulkGet(ctx context.Context, req []state.GetRequest) (bool, []state.BulkGetResponse, error) { return false, nil, nil } // Set adds/updates an entity on store. -func (s *SQLServer) Set(req *state.SetRequest) error { - return s.executeSet(s.db, req) +func (s *SQLServer) Set(ctx context.Context, req *state.SetRequest) error { + return s.executeSet(ctx, s.db, req) } // dbExecutor implements a common functionality implemented by db or tx. type dbExecutor interface { - Exec(query string, args ...interface{}) (sql.Result, error) + ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) } -func (s *SQLServer) executeSet(db dbExecutor, req *state.SetRequest) error { +func (s *SQLServer) executeSet(ctx context.Context, db dbExecutor, req *state.SetRequest) error { var err error var bytes []byte bytes, err = utils.Marshal(req.Value, json.Marshal) @@ -578,9 +579,9 @@ func (s *SQLServer) executeSet(db dbExecutor, req *state.SetRequest) error { var res sql.Result if req.Options.Concurrency == state.FirstWrite { - res, err = db.Exec(s.upsertCommand, sql.Named(keyColumnName, req.Key), sql.Named("Data", string(bytes)), etag, sql.Named("FirstWrite", 1)) + res, err = db.ExecContext(ctx, s.upsertCommand, sql.Named(keyColumnName, req.Key), sql.Named("Data", string(bytes)), etag, sql.Named("FirstWrite", 1)) } else { - res, err = db.Exec(s.upsertCommand, sql.Named(keyColumnName, req.Key), sql.Named("Data", string(bytes)), etag, sql.Named("FirstWrite", 0)) + res, err = db.ExecContext(ctx, s.upsertCommand, sql.Named(keyColumnName, req.Key), sql.Named("Data", string(bytes)), etag, sql.Named("FirstWrite", 0)) } if err != nil { @@ -604,14 +605,14 @@ func (s *SQLServer) executeSet(db dbExecutor, req *state.SetRequest) error { } // BulkSet adds/updates multiple entities on store. -func (s *SQLServer) BulkSet(req []state.SetRequest) error { +func (s *SQLServer) BulkSet(ctx context.Context, req []state.SetRequest) error { tx, err := s.db.Begin() if err != nil { return err } for i := range req { - err = s.executeSet(tx, &req[i]) + err = s.executeSet(ctx, tx, &req[i]) if err != nil { tx.Rollback() diff --git a/state/sqlserver/sqlserver_integration_test.go b/state/sqlserver/sqlserver_integration_test.go index dc8cd0346..24f009cd2 100644 --- a/state/sqlserver/sqlserver_integration_test.go +++ b/state/sqlserver/sqlserver_integration_test.go @@ -15,6 +15,7 @@ limitations under the License. package sqlserver import ( + "context" "database/sql" "encoding/json" "fmt" @@ -130,7 +131,7 @@ func getTestStoreWithKeyType(t *testing.T, kt KeyType, indexedProperties string) } func assertUserExists(t *testing.T, store *SQLServer, key string) (user, string) { - getRes, err := store.Get(&state.GetRequest{Key: key}) + getRes, err := store.Get(context.TODO(), &state.GetRequest{Key: key}) assert.Nil(t, err) assert.NotNil(t, getRes) assert.NotNil(t, getRes.Data, "No data was returned") @@ -153,7 +154,7 @@ func assertLoadedUserIsEqual(t *testing.T, store *SQLServer, key string, expecte } func assertUserDoesNotExist(t *testing.T, store *SQLServer, key string) { - _, err := store.Get(&state.GetRequest{Key: key}) + _, err := store.Get(context.TODO(), &state.GetRequest{Key: key}) assert.Nil(t, err) } @@ -224,14 +225,14 @@ func testSingleOperations(t *testing.T) { assertUserDoesNotExist(t, store, john.ID) // Save and read - err := store.Set(&state.SetRequest{Key: john.ID, Value: john}) + err := store.Set(context.TODO(), &state.SetRequest{Key: john.ID, Value: john}) assert.Nil(t, err) johnV1, etagFromInsert := assertLoadedUserIsEqual(t, store, john.ID, john) // Update with ETAG waterJohn := johnV1 waterJohn.FavoriteBeverage = "Water" - err = store.Set(&state.SetRequest{Key: waterJohn.ID, Value: waterJohn, ETag: &etagFromInsert}) + err = store.Set(context.TODO(), &state.SetRequest{Key: waterJohn.ID, Value: waterJohn, ETag: &etagFromInsert}) assert.Nil(t, err) // Get updated @@ -240,7 +241,7 @@ func testSingleOperations(t *testing.T) { // Update without ETAG noEtagJohn := johnV2 noEtagJohn.FavoriteBeverage = "No Etag John" - err = store.Set(&state.SetRequest{Key: noEtagJohn.ID, Value: noEtagJohn}) + err = store.Set(context.TODO(), &state.SetRequest{Key: noEtagJohn.ID, Value: noEtagJohn}) assert.Nil(t, err) // 7. Get updated @@ -249,17 +250,17 @@ func testSingleOperations(t *testing.T) { // 8. Update with invalid ETAG should fail failedJohn := johnV3 failedJohn.FavoriteBeverage = "Will not work" - err = store.Set(&state.SetRequest{Key: failedJohn.ID, Value: failedJohn, ETag: &etagFromInsert}) + err = store.Set(context.TODO(), &state.SetRequest{Key: failedJohn.ID, Value: failedJohn, ETag: &etagFromInsert}) assert.NotNil(t, err) _, etag := assertLoadedUserIsEqual(t, store, johnV3.ID, johnV3) // 9. Delete with invalid ETAG should fail - err = store.Delete(&state.DeleteRequest{Key: johnV3.ID, ETag: &invEtag}) + err = store.Delete(context.TODO(), &state.DeleteRequest{Key: johnV3.ID, ETag: &invEtag}) assert.NotNil(t, err) assertLoadedUserIsEqual(t, store, johnV3.ID, johnV3) // 10. Delete with valid ETAG - err = store.Delete(&state.DeleteRequest{Key: johnV2.ID, ETag: &etag}) + err = store.Delete(context.TODO(), &state.DeleteRequest{Key: johnV2.ID, ETag: &etag}) assert.Nil(t, err) assertUserDoesNotExist(t, store, johnV2.ID) @@ -273,7 +274,7 @@ func testSetNewRecordWithInvalidEtagShouldFail(t *testing.T) { u := user{uuid.New().String(), "John", "Coffee"} invEtag := invalidEtag - err := store.Set(&state.SetRequest{Key: u.ID, Value: u, ETag: &invEtag}) + err := store.Set(context.TODO(), &state.SetRequest{Key: u.ID, Value: u, ETag: &invEtag}) assert.NotNil(t, err) } @@ -281,7 +282,7 @@ func testSetNewRecordWithInvalidEtagShouldFail(t *testing.T) { func testIndexedProperties(t *testing.T) { store := getTestStore(t, `[{ "column":"FavoriteBeverage", "property":"FavoriteBeverage", "type":"nvarchar(100)"}, { "column":"PetsCount", "property":"PetsCount", "type": "INTEGER"}]`) - err := store.BulkSet([]state.SetRequest{ + err := store.BulkSet(context.TODO(), []state.SetRequest{ {Key: "1", Value: userWithPets{user{"1", "John", "Coffee"}, 3}}, {Key: "2", Value: userWithPets{user{"2", "Laura", "Water"}, 1}}, {Key: "3", Value: userWithPets{user{"3", "Carl", "Beer"}, 0}}, @@ -343,7 +344,7 @@ func testMultiOperations(t *testing.T) { bulkSet[i] = state.SetRequest{Key: u.ID, Value: u} } - err := store.BulkSet(bulkSet) + err := store.BulkSet(context.TODO(), bulkSet) assert.Nil(t, err) assertUserCountIsEqualTo(t, store, len(initialUsers)) @@ -363,7 +364,7 @@ func testMultiOperations(t *testing.T) { modified := original.user modified.FavoriteBeverage = beverageTea - localErr := store.Multi(&state.TransactionalStateRequest{ + localErr := store.Multi(context.TODO(), &state.TransactionalStateRequest{ Operations: []state.TransactionalStateOperation{ {Operation: state.Delete, Request: state.DeleteRequest{Key: toDelete.ID}}, {Operation: state.Upsert, Request: state.SetRequest{Key: modified.ID, Value: modified}}, @@ -386,7 +387,7 @@ func testMultiOperations(t *testing.T) { modified := toModify.user modified.FavoriteBeverage = beverageTea - err = store.Multi(&state.TransactionalStateRequest{ + err = store.Multi(context.TODO(), &state.TransactionalStateRequest{ Operations: []state.TransactionalStateOperation{ {Operation: state.Delete, Request: state.DeleteRequest{Key: toDelete.ID, ETag: &toDelete.etag}}, {Operation: state.Upsert, Request: state.SetRequest{Key: modified.ID, Value: modified, ETag: &toModify.etag}}, @@ -410,7 +411,7 @@ func testMultiOperations(t *testing.T) { modified := toModify.user modified.FavoriteBeverage = beverageTea - err = store.Multi(&state.TransactionalStateRequest{ + err = store.Multi(context.TODO(), &state.TransactionalStateRequest{ Operations: []state.TransactionalStateOperation{ {Operation: state.Delete, Request: state.DeleteRequest{Key: toDelete.ID, ETag: &toDelete.etag}}, {Operation: state.Upsert, Request: state.SetRequest{Key: modified.ID, Value: modified, ETag: &toModify.etag}}, @@ -431,7 +432,7 @@ func testMultiOperations(t *testing.T) { toInsert := user{keyGen.NextKey(), "Wont-be-inserted", "Beer"} invEtag := invalidEtag - err = store.Multi(&state.TransactionalStateRequest{ + err = store.Multi(context.TODO(), &state.TransactionalStateRequest{ Operations: []state.TransactionalStateOperation{ {Operation: state.Delete, Request: state.DeleteRequest{Key: toDelete.ID, ETag: &invEtag}}, {Operation: state.Upsert, Request: state.SetRequest{Key: toInsert.ID, Value: toInsert}}, @@ -452,7 +453,7 @@ func testMultiOperations(t *testing.T) { modified.FavoriteBeverage = beverageTea invEtag := invalidEtag - err = store.Multi(&state.TransactionalStateRequest{ + err = store.Multi(context.TODO(), &state.TransactionalStateRequest{ Operations: []state.TransactionalStateOperation{ {Operation: state.Delete, Request: state.DeleteRequest{Key: toDelete.ID, ETag: &invEtag}}, {Operation: state.Upsert, Request: state.SetRequest{Key: modified.ID, Value: modified}}, @@ -472,7 +473,7 @@ func testMultiOperations(t *testing.T) { modified.FavoriteBeverage = beverageTea invEtag := invalidEtag - err = store.Multi(&state.TransactionalStateRequest{ + err = store.Multi(context.TODO(), &state.TransactionalStateRequest{ Operations: []state.TransactionalStateOperation{ {Operation: state.Delete, Request: state.DeleteRequest{Key: toDelete.ID}}, {Operation: state.Upsert, Request: state.SetRequest{Key: modified.ID, Value: modified, ETag: &invEtag}}, @@ -520,7 +521,7 @@ func testBulkSet(t *testing.T) { sets[i] = state.SetRequest{Key: u.ID, Value: u} } - err := store.BulkSet(sets) + err := store.BulkSet(context.TODO(), sets) assert.Nil(t, err) totalUsers = len(sets) assertUserCountIsEqualTo(t, store, totalUsers) @@ -532,7 +533,7 @@ func testBulkSet(t *testing.T) { modified.FavoriteBeverage = beverageTea toInsert := user{keyGen.NextKey(), "Maria", "Wine"} - err := store.BulkSet([]state.SetRequest{ + err := store.BulkSet(context.TODO(), []state.SetRequest{ {Key: modified.ID, Value: modified, ETag: &toModifyETag}, {Key: toInsert.ID, Value: toInsert}, }) @@ -551,7 +552,7 @@ func testBulkSet(t *testing.T) { modified.FavoriteBeverage = beverageTea toInsert := user{keyGen.NextKey(), "Tony", "Milk"} - err := store.BulkSet([]state.SetRequest{ + err := store.BulkSet(context.TODO(), []state.SetRequest{ {Key: modified.ID, Value: modified}, {Key: toInsert.ID, Value: toInsert}, }) @@ -578,7 +579,7 @@ func testBulkSet(t *testing.T) { {Key: modified.ID, Value: modified, ETag: &invEtag}, } - err := store.BulkSet(sets) + err := store.BulkSet(context.TODO(), sets) assert.NotNil(t, err) assertUserCountIsEqualTo(t, store, totalUsers) assertUserDoesNotExist(t, store, toInsert1.ID) @@ -621,7 +622,7 @@ func testBulkDelete(t *testing.T) { for i, u := range initialUsers { sets[i] = state.SetRequest{Key: u.ID, Value: u} } - err := store.BulkSet(sets) + err := store.BulkSet(context.TODO(), sets) assert.Nil(t, err) totalUsers := len(initialUsers) assertUserCountIsEqualTo(t, store, totalUsers) @@ -631,7 +632,7 @@ func testBulkDelete(t *testing.T) { t.Run("Delete 2 items without etag should work", func(t *testing.T) { deleted1 := initialUsers[userIndex].ID deleted2 := initialUsers[userIndex+1].ID - err := store.BulkDelete([]state.DeleteRequest{ + err := store.BulkDelete(context.TODO(), []state.DeleteRequest{ {Key: deleted1}, {Key: deleted2}, }) @@ -648,7 +649,7 @@ func testBulkDelete(t *testing.T) { deleted1, deleted1Etag := assertUserExists(t, store, initialUsers[userIndex].ID) deleted2, deleted2Etag := assertUserExists(t, store, initialUsers[userIndex+1].ID) - err := store.BulkDelete([]state.DeleteRequest{ + err := store.BulkDelete(context.TODO(), []state.DeleteRequest{ {Key: deleted1.ID, ETag: &deleted1Etag}, {Key: deleted2.ID, ETag: &deleted2Etag}, }) @@ -665,7 +666,7 @@ func testBulkDelete(t *testing.T) { deleted1, deleted1Etag := assertUserExists(t, store, initialUsers[userIndex].ID) deleted2 := initialUsers[userIndex+1] - err := store.BulkDelete([]state.DeleteRequest{ + err := store.BulkDelete(context.TODO(), []state.DeleteRequest{ {Key: deleted1.ID, ETag: &deleted1Etag}, {Key: deleted2.ID}, }) @@ -683,7 +684,7 @@ func testBulkDelete(t *testing.T) { deleted2 := initialUsers[userIndex+1] invEtag := invalidEtag - err := store.BulkDelete([]state.DeleteRequest{ + err := store.BulkDelete(context.TODO(), []state.DeleteRequest{ {Key: deleted1.ID, ETag: &deleted1Etag}, {Key: deleted2.ID, ETag: &invEtag}, }) @@ -703,7 +704,7 @@ func testInsertAndUpdateSetRecordDates(t *testing.T) { store := getTestStore(t, "") u := user{"1", "John", "Coffee"} - err := store.Set(&state.SetRequest{Key: u.ID, Value: u}) + err := store.Set(context.TODO(), &state.SetRequest{Key: u.ID, Value: u}) assert.Nil(t, err) var originalInsertTime time.Time @@ -725,7 +726,7 @@ func testInsertAndUpdateSetRecordDates(t *testing.T) { modified := u modified.FavoriteBeverage = beverageTea - err = store.Set(&state.SetRequest{Key: modified.ID, Value: modified}) + err = store.Set(context.TODO(), &state.SetRequest{Key: modified.ID, Value: modified}) assert.Nil(t, err) assertDBQuery(t, store, getUserTsql, func(t *testing.T, rows *sql.Rows) { assert.True(t, rows.Next()) @@ -749,7 +750,7 @@ func testConcurrentSets(t *testing.T) { store := getTestStore(t, "") u := user{"1", "John", "Coffee"} - err := store.Set(&state.SetRequest{Key: u.ID, Value: u}) + err := store.Set(context.TODO(), &state.SetRequest{Key: u.ID, Value: u}) assert.Nil(t, err) _, etag := assertLoadedUserIsEqual(t, store, u.ID, u) @@ -766,7 +767,7 @@ func testConcurrentSets(t *testing.T) { defer wc.Done() modified := user{"1", "John", beverageTea} - err := store.Set(&state.SetRequest{Key: id, Value: modified, ETag: &etag}) + err := store.Set(context.TODO(), &state.SetRequest{Key: id, Value: modified, ETag: &etag}) if err != nil { atomic.AddInt32(&totalErrors, 1) } else { diff --git a/state/store.go b/state/store.go index a61df3251..592ffeb34 100644 --- a/state/store.go +++ b/state/store.go @@ -14,6 +14,7 @@ limitations under the License. package state import ( + "context" "fmt" "github.com/dapr/components-contrib/health" @@ -24,9 +25,9 @@ type Store interface { BulkStore Init(metadata Metadata) error Features() []Feature - Delete(req *DeleteRequest) error - Get(req *GetRequest) (*GetResponse, error) - Set(req *SetRequest) error + Delete(ctx context.Context, req *DeleteRequest) error + Get(ctx context.Context, req *GetRequest) (*GetResponse, error) + Set(ctx context.Context, req *SetRequest) error } func Ping(store Store) error { @@ -40,9 +41,9 @@ func Ping(store Store) error { // BulkStore is an interface to perform bulk operations on store. type BulkStore interface { - BulkDelete(req []DeleteRequest) error - BulkGet(req []GetRequest) (bool, []BulkGetResponse, error) - BulkSet(req []SetRequest) error + BulkDelete(ctx context.Context, req []DeleteRequest) error + BulkGet(ctx context.Context, req []GetRequest) (bool, []BulkGetResponse, error) + BulkSet(ctx context.Context, req []SetRequest) error } // DefaultBulkStore is a default implementation of BulkStore. @@ -64,16 +65,16 @@ func (b *DefaultBulkStore) Features() []Feature { } // BulkGet performs a bulks get operations. -func (b *DefaultBulkStore) BulkGet(req []GetRequest) (bool, []BulkGetResponse, error) { +func (b *DefaultBulkStore) BulkGet(ctx context.Context, req []GetRequest) (bool, []BulkGetResponse, error) { // by default, the store doesn't support bulk get // return false so daprd will fallback to call get() method one by one return false, nil, nil } // BulkSet performs a bulks save operation. -func (b *DefaultBulkStore) BulkSet(req []SetRequest) error { +func (b *DefaultBulkStore) BulkSet(ctx context.Context, req []SetRequest) error { for i := range req { - err := b.s.Set(&req[i]) + err := b.s.Set(ctx, &req[i]) if err != nil { return err } @@ -83,9 +84,9 @@ func (b *DefaultBulkStore) BulkSet(req []SetRequest) error { } // BulkDelete performs a bulk delete operation. -func (b *DefaultBulkStore) BulkDelete(req []DeleteRequest) error { +func (b *DefaultBulkStore) BulkDelete(ctx context.Context, req []DeleteRequest) error { for i := range req { - err := b.s.Delete(&req[i]) + err := b.s.Delete(ctx, &req[i]) if err != nil { return err } @@ -96,5 +97,5 @@ func (b *DefaultBulkStore) BulkDelete(req []DeleteRequest) error { // Querier is an interface to execute queries. type Querier interface { - Query(req *QueryRequest) (*QueryResponse, error) + Query(ctx context.Context, req *QueryRequest) (*QueryResponse, error) } diff --git a/state/store_test.go b/state/store_test.go index b0d0b2bf6..b04ae474c 100644 --- a/state/store_test.go +++ b/state/store_test.go @@ -14,6 +14,7 @@ limitations under the License. package state import ( + "context" "testing" "github.com/stretchr/testify/require" @@ -26,22 +27,22 @@ func TestStore_withDefaultBulkImpl(t *testing.T) { require.Equal(t, s.count, 0) require.Equal(t, s.bulkCount, 0) - store.Get(&GetRequest{}) - store.Set(&SetRequest{}) - store.Delete(&DeleteRequest{}) + store.Get(context.TODO(), &GetRequest{}) + store.Set(context.TODO(), &SetRequest{}) + store.Delete(context.TODO(), &DeleteRequest{}) require.Equal(t, 3, s.count) require.Equal(t, 0, s.bulkCount) - bulkGet, responses, err := store.BulkGet([]GetRequest{{}, {}, {}}) + bulkGet, responses, err := store.BulkGet(context.TODO(), []GetRequest{{}, {}, {}}) require.Equal(t, false, bulkGet) require.Equal(t, 0, len(responses)) require.NoError(t, err) require.Equal(t, 3, s.count) require.Equal(t, 0, s.bulkCount) - store.BulkSet([]SetRequest{{}, {}, {}, {}}) + store.BulkSet(context.TODO(), []SetRequest{{}, {}, {}, {}}) require.Equal(t, 3+4, s.count) require.Equal(t, 0, s.bulkCount) - store.BulkDelete([]DeleteRequest{{}, {}, {}, {}, {}}) + store.BulkDelete(context.TODO(), []DeleteRequest{{}, {}, {}, {}, {}}) require.Equal(t, 3+4+5, s.count) require.Equal(t, 0, s.bulkCount) } @@ -52,20 +53,20 @@ func TestStore_withCustomisedBulkImpl_notSupportBulkGet(t *testing.T) { require.Equal(t, s.count, 0) require.Equal(t, s.bulkCount, 0) - store.Get(&GetRequest{}) - store.Set(&SetRequest{}) - store.Delete(&DeleteRequest{}) + store.Get(context.TODO(), &GetRequest{}) + store.Set(context.TODO(), &SetRequest{}) + store.Delete(context.TODO(), &DeleteRequest{}) require.Equal(t, 3, s.count) require.Equal(t, 0, s.bulkCount) - bulkGet, _, _ := store.BulkGet([]GetRequest{{}, {}, {}}) + bulkGet, _, _ := store.BulkGet(context.TODO(), []GetRequest{{}, {}, {}}) require.Equal(t, false, bulkGet) require.Equal(t, 6, s.count) require.Equal(t, 0, s.bulkCount) - store.BulkSet([]SetRequest{{}, {}, {}, {}}) + store.BulkSet(context.TODO(), []SetRequest{{}, {}, {}, {}}) require.Equal(t, 6, s.count) require.Equal(t, 1, s.bulkCount) - store.BulkDelete([]DeleteRequest{{}, {}, {}, {}, {}}) + store.BulkDelete(context.TODO(), []DeleteRequest{{}, {}, {}, {}, {}}) require.Equal(t, 6, s.count) require.Equal(t, 2, s.bulkCount) } @@ -76,20 +77,20 @@ func TestStore_withCustomisedBulkImpl_supportBulkGet(t *testing.T) { require.Equal(t, s.count, 0) require.Equal(t, s.bulkCount, 0) - store.Get(&GetRequest{}) - store.Set(&SetRequest{}) - store.Delete(&DeleteRequest{}) + store.Get(context.TODO(), &GetRequest{}) + store.Set(context.TODO(), &SetRequest{}) + store.Delete(context.TODO(), &DeleteRequest{}) require.Equal(t, 3, s.count) require.Equal(t, 0, s.bulkCount) - bulkGet, _, _ := store.BulkGet([]GetRequest{{}, {}, {}}) + bulkGet, _, _ := store.BulkGet(context.TODO(), []GetRequest{{}, {}, {}}) require.Equal(t, true, bulkGet) require.Equal(t, 3, s.count) require.Equal(t, 1, s.bulkCount) - store.BulkSet([]SetRequest{{}, {}, {}, {}}) + store.BulkSet(context.TODO(), []SetRequest{{}, {}, {}, {}}) require.Equal(t, 3, s.count) require.Equal(t, 2, s.bulkCount) - store.BulkDelete([]DeleteRequest{{}, {}, {}, {}, {}}) + store.BulkDelete(context.TODO(), []DeleteRequest{{}, {}, {}, {}, {}}) require.Equal(t, 3, s.count) require.Equal(t, 3, s.bulkCount) } @@ -110,19 +111,19 @@ func (s *Store1) Init(metadata Metadata) error { return nil } -func (s *Store1) Delete(req *DeleteRequest) error { +func (s *Store1) Delete(ctx context.Context, req *DeleteRequest) error { s.count++ return nil } -func (s *Store1) Get(req *GetRequest) (*GetResponse, error) { +func (s *Store1) Get(ctx context.Context, req *GetRequest) (*GetResponse, error) { s.count++ return &GetResponse{}, nil } -func (s *Store1) Set(req *SetRequest) error { +func (s *Store1) Set(ctx context.Context, req *SetRequest) error { s.count++ return nil @@ -145,25 +146,25 @@ func (s *Store2) Features() []Feature { return nil } -func (s *Store2) Delete(req *DeleteRequest) error { +func (s *Store2) Delete(ctx context.Context, req *DeleteRequest) error { s.count++ return nil } -func (s *Store2) Get(req *GetRequest) (*GetResponse, error) { +func (s *Store2) Get(ctx context.Context, req *GetRequest) (*GetResponse, error) { s.count++ return &GetResponse{}, nil } -func (s *Store2) Set(req *SetRequest) error { +func (s *Store2) Set(ctx context.Context, req *SetRequest) error { s.count++ return nil } -func (s *Store2) BulkGet(req []GetRequest) (bool, []BulkGetResponse, error) { +func (s *Store2) BulkGet(ctx context.Context, req []GetRequest) (bool, []BulkGetResponse, error) { if s.supportBulkGet { s.bulkCount++ @@ -175,13 +176,13 @@ func (s *Store2) BulkGet(req []GetRequest) (bool, []BulkGetResponse, error) { return false, nil, nil } -func (s *Store2) BulkSet(req []SetRequest) error { +func (s *Store2) BulkSet(ctx context.Context, req []SetRequest) error { s.bulkCount++ return nil } -func (s *Store2) BulkDelete(req []DeleteRequest) error { +func (s *Store2) BulkDelete(ctx context.Context, req []DeleteRequest) error { s.bulkCount++ return nil diff --git a/state/zookeeper/zk.go b/state/zookeeper/zk.go index 3914f0c26..69b368469 100644 --- a/state/zookeeper/zk.go +++ b/state/zookeeper/zk.go @@ -14,6 +14,7 @@ limitations under the License. package zookeeper import ( + "context" "errors" "path" "strconv" @@ -161,7 +162,7 @@ func (s *StateStore) Features() []state.Feature { } // Get retrieves state from Zookeeper with a key. -func (s *StateStore) Get(req *state.GetRequest) (*state.GetResponse, error) { +func (s *StateStore) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) { value, stat, err := s.conn.Get(s.prefixedKey(req.Key)) if err != nil { if errors.Is(err, zk.ErrNoNode) { @@ -178,19 +179,19 @@ func (s *StateStore) Get(req *state.GetRequest) (*state.GetResponse, error) { } // BulkGet performs a bulks get operations. -func (s *StateStore) BulkGet(req []state.GetRequest) (bool, []state.BulkGetResponse, error) { +func (s *StateStore) BulkGet(ctx context.Context, req []state.GetRequest) (bool, []state.BulkGetResponse, error) { // TODO: replace with Multi for performance return false, nil, nil } // Delete performs a delete operation. -func (s *StateStore) Delete(req *state.DeleteRequest) error { +func (s *StateStore) Delete(ctx context.Context, req *state.DeleteRequest) error { r, err := s.newDeleteRequest(req) if err != nil { return err } - return state.DeleteWithOptions(func(req *state.DeleteRequest) error { + return state.DeleteWithOptions(ctx, func(ctx context.Context, req *state.DeleteRequest) error { err := s.conn.Delete(r.Path, r.Version) if errors.Is(err, zk.ErrNoNode) { return nil @@ -209,7 +210,7 @@ func (s *StateStore) Delete(req *state.DeleteRequest) error { } // BulkDelete performs a bulk delete operation. -func (s *StateStore) BulkDelete(reqs []state.DeleteRequest) error { +func (s *StateStore) BulkDelete(ctx context.Context, reqs []state.DeleteRequest) error { ops := make([]interface{}, 0, len(reqs)) for i := range reqs { @@ -236,13 +237,13 @@ func (s *StateStore) BulkDelete(reqs []state.DeleteRequest) error { } // Set saves state into Zookeeper. -func (s *StateStore) Set(req *state.SetRequest) error { +func (s *StateStore) Set(ctx context.Context, req *state.SetRequest) error { r, err := s.newSetDataRequest(req) if err != nil { return err } - return state.SetWithOptions(func(req *state.SetRequest) error { + return state.SetWithOptions(ctx, func(ctx context.Context, req *state.SetRequest) error { _, err = s.conn.Set(r.Path, r.Data, r.Version) if errors.Is(err, zk.ErrNoNode) { @@ -262,7 +263,7 @@ func (s *StateStore) Set(req *state.SetRequest) error { } // BulkSet performs a bulks save operation. -func (s *StateStore) BulkSet(reqs []state.SetRequest) error { +func (s *StateStore) BulkSet(ctx context.Context, reqs []state.SetRequest) error { ops := make([]interface{}, 0, len(reqs)) for i := range reqs { diff --git a/state/zookeeper/zk_test.go b/state/zookeeper/zk_test.go index d25226db7..ac1e22868 100644 --- a/state/zookeeper/zk_test.go +++ b/state/zookeeper/zk_test.go @@ -14,6 +14,7 @@ limitations under the License. package zookeeper import ( + "context" "fmt" "testing" "time" @@ -80,7 +81,7 @@ func TestGet(t *testing.T) { t.Run("With key exists", func(t *testing.T) { conn.EXPECT().Get("foo").Return([]byte("bar"), &zk.Stat{Version: 123}, nil).Times(1) - res, err := s.Get(&state.GetRequest{Key: "foo"}) + res, err := s.Get(context.TODO(), &state.GetRequest{Key: "foo"}) assert.NotNil(t, res, "Key must be exists") assert.Equal(t, "bar", string(res.Data), "Value must be equals") assert.Equal(t, ptr.String("123"), res.ETag, "ETag must be equals") @@ -90,7 +91,7 @@ func TestGet(t *testing.T) { t.Run("With key non-exists", func(t *testing.T) { conn.EXPECT().Get("foo").Return(nil, nil, zk.ErrNoNode).Times(1) - res, err := s.Get(&state.GetRequest{Key: "foo"}) + res, err := s.Get(context.TODO(), &state.GetRequest{Key: "foo"}) assert.Equal(t, &state.GetResponse{}, res, "Response must be empty") assert.NoError(t, err, "Non-existent key must not be treated as error") }) @@ -108,21 +109,21 @@ func TestDelete(t *testing.T) { t.Run("With key", func(t *testing.T) { conn.EXPECT().Delete("foo", int32(anyVersion)).Return(nil).Times(1) - err := s.Delete(&state.DeleteRequest{Key: "foo"}) + err := s.Delete(context.TODO(), &state.DeleteRequest{Key: "foo"}) assert.NoError(t, err, "Key must be exists") }) t.Run("With key and version", func(t *testing.T) { conn.EXPECT().Delete("foo", int32(123)).Return(nil).Times(1) - err := s.Delete(&state.DeleteRequest{Key: "foo", ETag: &etag}) + err := s.Delete(context.TODO(), &state.DeleteRequest{Key: "foo", ETag: &etag}) assert.NoError(t, err, "Key must be exists") }) t.Run("With key and concurrency", func(t *testing.T) { conn.EXPECT().Delete("foo", int32(anyVersion)).Return(nil).Times(1) - err := s.Delete(&state.DeleteRequest{ + err := s.Delete(context.TODO(), &state.DeleteRequest{ Key: "foo", ETag: &etag, Options: state.DeleteStateOption{Concurrency: state.LastWrite}, @@ -133,14 +134,14 @@ func TestDelete(t *testing.T) { t.Run("With delete error", func(t *testing.T) { conn.EXPECT().Delete("foo", int32(anyVersion)).Return(zk.ErrUnknown).Times(1) - err := s.Delete(&state.DeleteRequest{Key: "foo"}) + err := s.Delete(context.TODO(), &state.DeleteRequest{Key: "foo"}) assert.EqualError(t, err, "zk: unknown error") }) t.Run("With delete and ignore NoNode error", func(t *testing.T) { conn.EXPECT().Delete("foo", int32(anyVersion)).Return(zk.ErrNoNode).Times(1) - err := s.Delete(&state.DeleteRequest{Key: "foo"}) + err := s.Delete(context.TODO(), &state.DeleteRequest{Key: "foo"}) assert.NoError(t, err, "Delete must be successful") }) } @@ -159,7 +160,7 @@ func TestBulkDelete(t *testing.T) { &zk.DeleteRequest{Path: "bar", Version: int32(anyVersion)}, }).Return([]zk.MultiResponse{{}, {}}, nil).Times(1) - err := s.BulkDelete([]state.DeleteRequest{{Key: "foo"}, {Key: "bar"}}) + err := s.BulkDelete(context.TODO(), []state.DeleteRequest{{Key: "foo"}, {Key: "bar"}}) assert.NoError(t, err, "Key must be exists") }) @@ -171,7 +172,7 @@ func TestBulkDelete(t *testing.T) { {Error: zk.ErrUnknown}, {Error: zk.ErrNoAuth}, }, nil).Times(1) - err := s.BulkDelete([]state.DeleteRequest{{Key: "foo"}, {Key: "bar"}}) + err := s.BulkDelete(context.TODO(), []state.DeleteRequest{{Key: "foo"}, {Key: "bar"}}) assert.Equal(t, err.(*multierror.Error).Errors, []error{zk.ErrUnknown, zk.ErrNoAuth}) }) t.Run("With keys and ignore NoNode error", func(t *testing.T) { @@ -182,7 +183,7 @@ func TestBulkDelete(t *testing.T) { {Error: zk.ErrNoNode}, {}, }, nil).Times(1) - err := s.BulkDelete([]state.DeleteRequest{{Key: "foo"}, {Key: "bar"}}) + err := s.BulkDelete(context.TODO(), []state.DeleteRequest{{Key: "foo"}, {Key: "bar"}}) assert.NoError(t, err, "Key must be exists") }) } @@ -201,19 +202,19 @@ func TestSet(t *testing.T) { t.Run("With key", func(t *testing.T) { conn.EXPECT().Set("foo", []byte("\"bar\""), int32(anyVersion)).Return(stat, nil).Times(1) - err := s.Set(&state.SetRequest{Key: "foo", Value: "bar"}) + err := s.Set(context.TODO(), &state.SetRequest{Key: "foo", Value: "bar"}) assert.NoError(t, err, "Key must be set") }) t.Run("With key and version", func(t *testing.T) { conn.EXPECT().Set("foo", []byte("\"bar\""), int32(123)).Return(stat, nil).Times(1) - err := s.Set(&state.SetRequest{Key: "foo", Value: "bar", ETag: &etag}) + err := s.Set(context.TODO(), &state.SetRequest{Key: "foo", Value: "bar", ETag: &etag}) assert.NoError(t, err, "Key must be set") }) t.Run("With key and concurrency", func(t *testing.T) { conn.EXPECT().Set("foo", []byte("\"bar\""), int32(anyVersion)).Return(stat, nil).Times(1) - err := s.Set(&state.SetRequest{ + err := s.Set(context.TODO(), &state.SetRequest{ Key: "foo", Value: "bar", ETag: &etag, @@ -225,14 +226,14 @@ func TestSet(t *testing.T) { t.Run("With error", func(t *testing.T) { conn.EXPECT().Set("foo", []byte("\"bar\""), int32(anyVersion)).Return(nil, zk.ErrUnknown).Times(1) - err := s.Set(&state.SetRequest{Key: "foo", Value: "bar"}) + err := s.Set(context.TODO(), &state.SetRequest{Key: "foo", Value: "bar"}) assert.EqualError(t, err, "zk: unknown error") }) t.Run("With NoNode error and retry", func(t *testing.T) { conn.EXPECT().Set("foo", []byte("\"bar\""), int32(anyVersion)).Return(nil, zk.ErrNoNode).Times(1) conn.EXPECT().Create("foo", []byte("\"bar\""), int32(0), nil).Return("/foo", nil).Times(1) - err := s.Set(&state.SetRequest{Key: "foo", Value: "bar"}) + err := s.Set(context.TODO(), &state.SetRequest{Key: "foo", Value: "bar"}) assert.NoError(t, err, "Key must be create") }) } @@ -251,7 +252,7 @@ func TestBulkSet(t *testing.T) { &zk.SetDataRequest{Path: "bar", Data: []byte("\"foo\""), Version: int32(anyVersion)}, }).Return([]zk.MultiResponse{{}, {}}, nil).Times(1) - err := s.BulkSet([]state.SetRequest{ + err := s.BulkSet(context.TODO(), []state.SetRequest{ {Key: "foo", Value: "bar"}, {Key: "bar", Value: "foo"}, }) @@ -266,7 +267,7 @@ func TestBulkSet(t *testing.T) { {Error: zk.ErrUnknown}, {Error: zk.ErrNoAuth}, }, nil).Times(1) - err := s.BulkSet([]state.SetRequest{ + err := s.BulkSet(context.TODO(), []state.SetRequest{ {Key: "foo", Value: "bar"}, {Key: "bar", Value: "foo"}, }) @@ -283,7 +284,7 @@ func TestBulkSet(t *testing.T) { &zk.CreateRequest{Path: "foo", Data: []byte("\"bar\"")}, }).Return([]zk.MultiResponse{{}, {}}, nil).Times(1) - err := s.BulkSet([]state.SetRequest{ + err := s.BulkSet(context.TODO(), []state.SetRequest{ {Key: "foo", Value: "bar"}, {Key: "bar", Value: "foo"}, }) diff --git a/tests/conformance/state/state.go b/tests/conformance/state/state.go index a57bae49c..f266b859d 100644 --- a/tests/conformance/state/state.go +++ b/tests/conformance/state/state.go @@ -14,6 +14,7 @@ limitations under the License. package state import ( + "context" "encoding/json" "fmt" "sort" @@ -251,7 +252,7 @@ func ConformanceTests(t *testing.T, props map[string]string, statestore state.St if len(scenario.contentType) != 0 { req.Metadata = map[string]string{metadata.ContentType: scenario.contentType} } - err := statestore.Set(req) + err := statestore.Set(context.TODO(), req) assert.Nil(t, err) } } @@ -269,7 +270,7 @@ func ConformanceTests(t *testing.T, props map[string]string, statestore state.St if len(scenario.contentType) != 0 { req.Metadata = map[string]string{metadata.ContentType: scenario.contentType} } - res, err := statestore.Get(req) + res, err := statestore.Get(context.TODO(), req) assert.Nil(t, err) assertEquals(t, scenario.value, res) } @@ -290,7 +291,7 @@ func ConformanceTests(t *testing.T, props map[string]string, statestore state.St metadata.ContentType: contenttype.JSONContentType, metadata.QueryIndexName: "qIndx", } - resp, err := querier.Query(&req) + resp, err := querier.Query(context.TODO(), &req) assert.NoError(t, err) assert.Equal(t, len(scenario.results), len(resp.Results)) for i := range scenario.results { @@ -318,11 +319,11 @@ func ConformanceTests(t *testing.T, props map[string]string, statestore state.St if len(scenario.contentType) != 0 { req.Metadata = map[string]string{metadata.ContentType: scenario.contentType} } - err := statestore.Delete(req) + err := statestore.Delete(context.TODO(), req) assert.Nil(t, err, "no error expected while deleting %s", scenario.key) t.Logf("Checking value absence for %s", scenario.key) - res, err := statestore.Get(&state.GetRequest{ + res, err := statestore.Get(context.TODO(), &state.GetRequest{ Key: scenario.key, }) assert.Nil(t, err, "no error expected while checking for absence for %s", scenario.key) @@ -344,14 +345,14 @@ func ConformanceTests(t *testing.T, props map[string]string, statestore state.St }) } } - err := statestore.BulkSet(bulk) + err := statestore.BulkSet(context.TODO(), bulk) assert.Nil(t, err) for _, scenario := range scenarios { if scenario.bulkOnly { t.Logf("Checking value presence for %s", scenario.key) // Data should have been inserted at this point - res, err := statestore.Get(&state.GetRequest{ + res, err := statestore.Get(context.TODO(), &state.GetRequest{ Key: scenario.key, }) assert.Nil(t, err) @@ -372,12 +373,12 @@ func ConformanceTests(t *testing.T, props map[string]string, statestore state.St }) } } - err := statestore.BulkDelete(bulk) + err := statestore.BulkDelete(context.TODO(), bulk) assert.Nil(t, err) for _, req := range bulk { t.Logf("Checking value absence for %s", req.Key) - res, err := statestore.Get(&state.GetRequest{ + res, err := statestore.Get(context.TODO(), &state.GetRequest{ Key: req.Key, }) assert.Nil(t, err) @@ -443,7 +444,7 @@ func ConformanceTests(t *testing.T, props map[string]string, statestore state.St if scenario.transactionGroup == transactionGroup { t.Logf("Checking value presence for %s", scenario.key) // Data should have been inserted at this point - res, err := statestore.Get(&state.GetRequest{ + res, err := statestore.Get(context.TODO(), &state.GetRequest{ Key: scenario.key, // For CosmosDB Metadata: map[string]string{ @@ -457,7 +458,7 @@ func ConformanceTests(t *testing.T, props map[string]string, statestore state.St if scenario.toBeDeleted && (scenario.transactionGroup == transactionGroup-1) { t.Logf("Checking value absence for %s", scenario.key) // Data should have been deleted at this point - res, err := statestore.Get(&state.GetRequest{ + res, err := statestore.Get(context.TODO(), &state.GetRequest{ Key: scenario.key, // For CosmosDB Metadata: map[string]string{ @@ -487,7 +488,7 @@ func ConformanceTests(t *testing.T, props map[string]string, statestore state.St } // prerequisite: key1 should be present - err := statestore.Set(&state.SetRequest{ + err := statestore.Set(context.TODO(), &state.SetRequest{ Key: firstKey, Value: firstValue, Metadata: partitionMetadata, @@ -495,14 +496,14 @@ func ConformanceTests(t *testing.T, props map[string]string, statestore state.St assert.NoError(t, err, "set request should be successful") // prerequisite: key2 should not be present - err = statestore.Delete(&state.DeleteRequest{ + err = statestore.Delete(context.TODO(), &state.DeleteRequest{ Key: secondKey, Metadata: partitionMetadata, }) assert.NoError(t, err, "delete request should be successful") // prerequisite: key3 should not be present - err = statestore.Delete(&state.DeleteRequest{ + err = statestore.Delete(context.TODO(), &state.DeleteRequest{ Key: thirdKey, Metadata: partitionMetadata, }) @@ -558,7 +559,7 @@ func ConformanceTests(t *testing.T, props map[string]string, statestore state.St // Assert for k, v := range expected { - res, err := statestore.Get(&state.GetRequest{ + res, err := statestore.Get(context.TODO(), &state.GetRequest{ Key: k, Metadata: partitionMetadata, }) @@ -585,20 +586,20 @@ func ConformanceTests(t *testing.T, props map[string]string, statestore state.St require.True(t, state.FeatureETag.IsPresent(features)) // Delete any potential object, it's important to start from a clean slate. - err := statestore.Delete(&state.DeleteRequest{ + err := statestore.Delete(context.TODO(), &state.DeleteRequest{ Key: testKey, }) require.Nil(t, err) // Set an object. - err = statestore.Set(&state.SetRequest{ + err = statestore.Set(context.TODO(), &state.SetRequest{ Key: testKey, Value: firstValue, }) require.Nil(t, err) // Validate the set. - res, err := statestore.Get(&state.GetRequest{ + res, err := statestore.Get(context.TODO(), &state.GetRequest{ Key: testKey, }) @@ -607,7 +608,7 @@ func ConformanceTests(t *testing.T, props map[string]string, statestore state.St etag := res.ETag // Try and update with wrong ETag, expect failure. - err = statestore.Set(&state.SetRequest{ + err = statestore.Set(context.TODO(), &state.SetRequest{ Key: testKey, Value: secondValue, ETag: &fakeEtag, @@ -615,7 +616,7 @@ func ConformanceTests(t *testing.T, props map[string]string, statestore state.St require.NotNil(t, err) // Try and update with corect ETag, expect success. - err = statestore.Set(&state.SetRequest{ + err = statestore.Set(context.TODO(), &state.SetRequest{ Key: testKey, Value: secondValue, ETag: etag, @@ -623,7 +624,7 @@ func ConformanceTests(t *testing.T, props map[string]string, statestore state.St require.Nil(t, err) // Validate the set. - res, err = statestore.Get(&state.GetRequest{ + res, err = statestore.Get(context.TODO(), &state.GetRequest{ Key: testKey, }) require.Nil(t, err) @@ -632,14 +633,14 @@ func ConformanceTests(t *testing.T, props map[string]string, statestore state.St etag = res.ETag // Try and delete with wrong ETag, expect failure. - err = statestore.Delete(&state.DeleteRequest{ + err = statestore.Delete(context.TODO(), &state.DeleteRequest{ Key: testKey, ETag: &fakeEtag, }) require.NotNil(t, err) // Try and delete with correct ETag, expect success. - err = statestore.Delete(&state.DeleteRequest{ + err = statestore.Delete(context.TODO(), &state.DeleteRequest{ Key: testKey, ETag: etag, }) @@ -698,23 +699,23 @@ func ConformanceTests(t *testing.T, props map[string]string, statestore state.St for i, requestSet := range requestSets { t.Run(fmt.Sprintf("request set %d", i), func(t *testing.T) { // Delete any potential object, it's important to start from a clean slate. - err := statestore.Delete(&state.DeleteRequest{ + err := statestore.Delete(context.TODO(), &state.DeleteRequest{ Key: testKey, }) require.Nil(t, err) - err = statestore.Set(requestSet[0]) + err = statestore.Set(context.TODO(), requestSet[0]) require.Nil(t, err) // Validate the set. - res, err := statestore.Get(&state.GetRequest{ + res, err := statestore.Get(context.TODO(), &state.GetRequest{ Key: testKey, }) require.Nil(t, err) assertEquals(t, firstValue, res) // Second write expect fail - err = statestore.Set(requestSet[1]) + err = statestore.Set(context.TODO(), requestSet[1]) require.NotNil(t, err) }) } @@ -731,16 +732,16 @@ func ConformanceTests(t *testing.T, props map[string]string, statestore state.St } // Delete any potential object, it's important to start from a clean slate. - err := statestore.Delete(&state.DeleteRequest{ + err := statestore.Delete(context.TODO(), &state.DeleteRequest{ Key: testKey, }) require.Nil(t, err) - err = statestore.Set(request) + err = statestore.Set(context.TODO(), request) require.Nil(t, err) // Validate the set. - res, err := statestore.Get(&state.GetRequest{ + res, err := statestore.Get(context.TODO(), &state.GetRequest{ Key: testKey, }) require.Nil(t, err) @@ -757,11 +758,11 @@ func ConformanceTests(t *testing.T, props map[string]string, statestore state.St Consistency: state.Strong, }, } - err = statestore.Set(request) + err = statestore.Set(context.TODO(), request) require.Nil(t, err) // Validate the set. - res, err = statestore.Get(&state.GetRequest{ + res, err = statestore.Get(context.TODO(), &state.GetRequest{ Key: testKey, }) require.Nil(t, err) @@ -771,7 +772,7 @@ func ConformanceTests(t *testing.T, props map[string]string, statestore state.St request.ETag = etag // Second write expect fail - err = statestore.Set(request) + err = statestore.Set(context.TODO(), request) require.NotNil(t, err) }) }