From 964c2089777c14ef798d8c644b0416181dbfb8db Mon Sep 17 00:00:00 2001 From: Mike Brown Date: Thu, 18 Aug 2022 16:13:19 -0400 Subject: [PATCH] dynamodb state: add support for etags In order to safely support concurrent updates to existing state items clients should employ etags. (For reference please see tinyurl.com/5n83tnfp). Existing support for AWS DynamoDB does not include support for etags. This change introduces etag support for AWS DynamoDB utilizing conditional expressions (reference tinyurl.com/5du587m8). Signed-off-by: Mike Brown --- state/aws/dynamodb/dynamodb.go | 89 ++++++++++- state/aws/dynamodb/dynamodb_test.go | 230 +++++++++++++++++++++++++--- 2 files changed, 291 insertions(+), 28 deletions(-) diff --git a/state/aws/dynamodb/dynamodb.go b/state/aws/dynamodb/dynamodb.go index 3aaae8d17..d29f8d7c1 100644 --- a/state/aws/dynamodb/dynamodb.go +++ b/state/aws/dynamodb/dynamodb.go @@ -14,6 +14,8 @@ limitations under the License. package dynamodb import ( + "crypto/rand" + "encoding/binary" "encoding/json" "fmt" "strconv" @@ -73,7 +75,7 @@ func (d *StateStore) Init(metadata state.Metadata) error { // Features returns the features available in this state store. func (d *StateStore) Features() []state.Feature { - return nil + return []state.Feature{state.FeatureETag} } // Get retrieves a dynamoDB item. @@ -115,9 +117,19 @@ func (d *StateStore) Get(req *state.GetRequest) (*state.GetResponse, error) { } } - return &state.GetResponse{ + resp := &state.GetResponse{ Data: []byte(output), - }, nil + } + + var etag string + if etagVal, ok := result.Item["etag"]; ok { + if err = dynamodbattribute.Unmarshal(etagVal, &etag); err != nil { + return nil, err + } + resp.ETag = &etag + } + + return resp, nil } // BulkGet performs a bulk get operations. @@ -138,17 +150,42 @@ func (d *StateStore) Set(req *state.SetRequest) error { TableName: &d.table, } - _, e := d.client.PutItem(input) + if req.ETag != nil && *req.ETag != "" { + condExpr := "etag = :etag" + input.ConditionExpression = &condExpr + exprAttrValues := make(map[string]*dynamodb.AttributeValue) + exprAttrValues[":etag"] = &dynamodb.AttributeValue{ + S: req.ETag, + } + input.ExpressionAttributeValues = exprAttrValues + } - return e + _, err = d.client.PutItem(input) + if err != nil { + switch cErr := err.(type) { + case *dynamodb.ConditionalCheckFailedException: + err = state.NewETagError(state.ETagMismatch, cErr) + } + } + + return err } // BulkSet performs a bulk set operation. func (d *StateStore) BulkSet(req []state.SetRequest) error { writeRequests := []*dynamodb.WriteRequest{} + if len(req) == 1 { + return d.Set(&req[0]) + } + for _, r := range req { r := r // avoid G601. + + if r.ETag != nil && *r.ETag != "" { + return fmt.Errorf("dynamodb error: BulkSet() does not support etags; please use Set() instead") + } + item, err := d.getItemFromReq(&r) if err != nil { return err @@ -183,7 +220,24 @@ func (d *StateStore) Delete(req *state.DeleteRequest) error { }, TableName: aws.String(d.table), } + + if req.ETag != nil && *req.ETag != "" { + condExpr := "etag = :etag" + input.ConditionExpression = &condExpr + exprAttrValues := make(map[string]*dynamodb.AttributeValue) + exprAttrValues[":etag"] = &dynamodb.AttributeValue{ + S: req.ETag, + } + input.ExpressionAttributeValues = exprAttrValues + } + _, err := d.client.DeleteItem(input) + if err != nil { + switch cErr := err.(type) { + case *dynamodb.ConditionalCheckFailedException: + err = state.NewETagError(state.ETagMismatch, cErr) + } + } return err } @@ -192,7 +246,15 @@ func (d *StateStore) Delete(req *state.DeleteRequest) error { func (d *StateStore) BulkDelete(req []state.DeleteRequest) error { writeRequests := []*dynamodb.WriteRequest{} + if len(req) == 1 { + return d.Delete(&req[0]) + } + for _, r := range req { + if r.ETag != nil && *r.ETag != "" { + return fmt.Errorf("dynamodb error: BulkDelete() does not support etags; please use Delete() instead") + } + writeRequest := &dynamodb.WriteRequest{ DeleteRequest: &dynamodb.DeleteRequest{ Key: map[string]*dynamodb.AttributeValue{ @@ -255,6 +317,10 @@ func (d *StateStore) getItemFromReq(req *state.SetRequest) (map[string]*dynamodb return nil, fmt.Errorf("dynamodb error: failed to parse ttlInSeconds: %s", err) } + newEtag, err := getRand64() + if err != nil { + return nil, fmt.Errorf("dynamodb error: failed to generate etag: %w", err) + } item := map[string]*dynamodb.AttributeValue{ "key": { S: aws.String(req.Key), @@ -262,6 +328,9 @@ func (d *StateStore) getItemFromReq(req *state.SetRequest) (map[string]*dynamodb "value": { S: aws.String(value), }, + "etag": { + S: aws.String(strconv.FormatUint(newEtag, 16)), + }, } if ttl != nil { @@ -273,6 +342,16 @@ func (d *StateStore) getItemFromReq(req *state.SetRequest) (map[string]*dynamodb return item, nil } +func getRand64() (uint64, error) { + randBuf := make([]byte, 8) + _, err := rand.Read(randBuf) + if err != nil { + return 0, err + } + + return binary.LittleEndian.Uint64(randBuf), nil +} + func (d *StateStore) marshalToString(v interface{}) (string, error) { if buf, ok := v.([]byte); ok { return string(buf), nil diff --git a/state/aws/dynamodb/dynamodb_test.go b/state/aws/dynamodb/dynamodb_test.go index 9cfd47f62..b7552097c 100644 --- a/state/aws/dynamodb/dynamodb_test.go +++ b/state/aws/dynamodb/dynamodb_test.go @@ -108,6 +108,9 @@ func TestGet(t *testing.T) { "value": { S: aws.String("some value"), }, + "etag": { + S: aws.String("1bdead4badc0ffee"), + }, }, }, nil }, @@ -123,6 +126,7 @@ func TestGet(t *testing.T) { out, err := ss.Get(req) assert.Nil(t, err) assert.Equal(t, []byte("some value"), out.Data) + assert.Equal(t, "1bdead4badc0ffee", *out.ETag) }) t.Run("Successfully retrieve item (with unexpired ttl)", func(t *testing.T) { ss := StateStore{ @@ -139,6 +143,9 @@ func TestGet(t *testing.T) { "testAttributeName": { N: aws.String("4074862051"), }, + "etag": { + S: aws.String("1bdead4badc0ffee"), + }, }, }, nil }, @@ -155,6 +162,7 @@ func TestGet(t *testing.T) { out, err := ss.Get(req) assert.Nil(t, err) assert.Equal(t, []byte("some value"), out.Data) + assert.Equal(t, "1bdead4badc0ffee", *out.ETag) }) t.Run("Successfully retrieve item (with expired ttl)", func(t *testing.T) { ss := StateStore{ @@ -171,6 +179,9 @@ func TestGet(t *testing.T) { "testAttributeName": { N: aws.String("35489251"), }, + "etag": { + S: aws.String("1bdead4badc0ffee"), + }, }, }, nil }, @@ -187,6 +198,7 @@ func TestGet(t *testing.T) { out, err := ss.Get(req) assert.Nil(t, err) assert.Nil(t, out.Data) + assert.Nil(t, out.ETag) }) t.Run("Unsuccessfully get item", func(t *testing.T) { ss := StateStore{ @@ -227,6 +239,7 @@ func TestGet(t *testing.T) { out, err := ss.Get(req) assert.Nil(t, err) assert.Nil(t, out.Data) + assert.Nil(t, out.ETag) }) t.Run("Unsuccessfully with no required key", func(t *testing.T) { ss := StateStore{ @@ -264,15 +277,13 @@ func TestSet(t *testing.T) { ss := StateStore{ client: &mockedDynamoDB{ PutItemFn: func(input *dynamodb.PutItemInput) (output *dynamodb.PutItemOutput, err error) { - assert.Equal(t, map[string]*dynamodb.AttributeValue{ - "key": { - S: aws.String("key"), - }, - "value": { - S: aws.String(`{"Value":"value"}`), - }, - }, input.Item) - assert.Equal(t, len(input.Item), 2) + assert.Equal(t, dynamodb.AttributeValue{ + S: aws.String("key"), + }, *input.Item["key"]) + assert.Equal(t, dynamodb.AttributeValue{ + S: aws.String(`{"Value":"value"}`), + }, *input.Item["value"]) + assert.Equal(t, len(input.Item), 3) return &dynamodb.PutItemOutput{ Attributes: map[string]*dynamodb.AttributeValue{ @@ -294,11 +305,89 @@ func TestSet(t *testing.T) { assert.Nil(t, err) }) + t.Run("Successfully set item with matching etag", func(t *testing.T) { + ss := StateStore{ + client: &mockedDynamoDB{ + PutItemFn: func(input *dynamodb.PutItemInput) (output *dynamodb.PutItemOutput, err error) { + assert.Equal(t, dynamodb.AttributeValue{ + S: aws.String("key"), + }, *input.Item["key"]) + assert.Equal(t, dynamodb.AttributeValue{ + S: aws.String(`{"Value":"value"}`), + }, *input.Item["value"]) + assert.Equal(t, "etag = :etag", *input.ConditionExpression) + assert.Equal(t, &dynamodb.AttributeValue{ + S: aws.String("1bdead4badc0ffee"), + }, input.ExpressionAttributeValues[":etag"]) + assert.Equal(t, len(input.Item), 3) + + return &dynamodb.PutItemOutput{ + Attributes: map[string]*dynamodb.AttributeValue{ + "key": { + S: aws.String("value"), + }, + }, + }, nil + }, + }, + } + etag := "1bdead4badc0ffee" + req := &state.SetRequest{ + ETag: &etag, + Key: "key", + Value: value{ + Value: "value", + }, + } + err := ss.Set(req) + assert.Nil(t, err) + }) + + t.Run("Unsuccessfully set item with mismatched etag", func(t *testing.T) { + ss := StateStore{ + client: &mockedDynamoDB{ + PutItemFn: func(input *dynamodb.PutItemInput) (output *dynamodb.PutItemOutput, err error) { + assert.Equal(t, dynamodb.AttributeValue{ + S: aws.String("key"), + }, *input.Item["key"]) + assert.Equal(t, dynamodb.AttributeValue{ + S: aws.String(`{"Value":"value"}`), + }, *input.Item["value"]) + assert.Equal(t, "etag = :etag", *input.ConditionExpression) + assert.Equal(t, &dynamodb.AttributeValue{ + S: aws.String("bogusetag"), + }, input.ExpressionAttributeValues[":etag"]) + assert.Equal(t, len(input.Item), 3) + + var checkErr dynamodb.ConditionalCheckFailedException + return nil, &checkErr + }, + }, + } + etag := "bogusetag" + req := &state.SetRequest{ + ETag: &etag, + Key: "key", + Value: value{ + Value: "value", + }, + } + + err := ss.Set(req) + assert.NotNil(t, err) + switch tagErr := err.(type) { + case *state.ETagError: + assert.Equal(t, tagErr.Kind(), state.ETagMismatch) + default: + assert.True(t, false) + } + }) + t.Run("Successfully set item with ttl = -1", func(t *testing.T) { ss := StateStore{ client: &mockedDynamoDB{ PutItemFn: func(input *dynamodb.PutItemInput) (output *dynamodb.PutItemOutput, err error) { - assert.Equal(t, len(input.Item), 3) + assert.Equal(t, len(input.Item), 4) result := DynamoDBItem{} dynamodbattribute.UnmarshalMap(input.Item, &result) assert.Equal(t, result.Key, "someKey") @@ -333,7 +422,7 @@ func TestSet(t *testing.T) { ss := StateStore{ client: &mockedDynamoDB{ PutItemFn: func(input *dynamodb.PutItemInput) (output *dynamodb.PutItemOutput, err error) { - assert.Equal(t, len(input.Item), 3) + assert.Equal(t, len(input.Item), 4) result := DynamoDBItem{} dynamodbattribute.UnmarshalMap(input.Item, &result) assert.Equal(t, result.Key, "someKey") @@ -386,15 +475,13 @@ func TestSet(t *testing.T) { ss := StateStore{ client: &mockedDynamoDB{ PutItemFn: func(input *dynamodb.PutItemInput) (output *dynamodb.PutItemOutput, err error) { - assert.Equal(t, map[string]*dynamodb.AttributeValue{ - "key": { - S: aws.String("someKey"), - }, - "value": { - S: aws.String(`{"Value":"someValue"}`), - }, - }, input.Item) - assert.Equal(t, len(input.Item), 2) + assert.Equal(t, dynamodb.AttributeValue{ + S: aws.String("someKey"), + }, *input.Item["key"]) + assert.Equal(t, dynamodb.AttributeValue{ + S: aws.String(`{"Value":"someValue"}`), + }, *input.Item["value"]) + assert.Equal(t, len(input.Item), 3) return &dynamodb.PutItemOutput{ Attributes: map[string]*dynamodb.AttributeValue{ @@ -497,7 +584,16 @@ func TestBulkSet(t *testing.T) { }, }, } - assert.Equal(t, expected, input.RequestItems) + + for tbl := range expected { + for reqNum := range expected[tbl] { + expectedItem := expected[tbl][reqNum].PutRequest.Item + inputItem := input.RequestItems[tbl][reqNum].PutRequest.Item + + assert.Equal(t, expectedItem["key"], inputItem["key"]) + assert.Equal(t, expectedItem["value"], inputItem["value"]) + } + } return &dynamodb.BatchWriteItemOutput{ UnprocessedItems: map[string][]*dynamodb.WriteRequest{}, @@ -558,7 +654,15 @@ func TestBulkSet(t *testing.T) { }, }, } - assert.Equal(t, expected, input.RequestItems) + for tbl := range expected { + for reqNum := range expected[tbl] { + expectedItem := expected[tbl][reqNum].PutRequest.Item + inputItem := input.RequestItems[tbl][reqNum].PutRequest.Item + + assert.Equal(t, expectedItem["key"], inputItem["key"]) + assert.Equal(t, expectedItem["value"], inputItem["value"]) + } + } return &dynamodb.BatchWriteItemOutput{ UnprocessedItems: map[string][]*dynamodb.WriteRequest{}, @@ -625,7 +729,15 @@ func TestBulkSet(t *testing.T) { }, }, } - assert.Equal(t, expected, input.RequestItems) + for tbl := range expected { + for reqNum := range expected[tbl] { + expectedItem := expected[tbl][reqNum].PutRequest.Item + inputItem := input.RequestItems[tbl][reqNum].PutRequest.Item + + assert.Equal(t, expectedItem["key"], inputItem["key"]) + assert.Equal(t, expectedItem["value"], inputItem["value"]) + } + } return &dynamodb.BatchWriteItemOutput{ UnprocessedItems: map[string][]*dynamodb.WriteRequest{}, @@ -670,6 +782,12 @@ func TestBulkSet(t *testing.T) { Value: "value", }, }, + { + Key: "key", + Value: value{ + Value: "value", + }, + }, } err := ss.BulkSet(req) assert.NotNil(t, err) @@ -699,6 +817,69 @@ func TestDelete(t *testing.T) { assert.Nil(t, err) }) + t.Run("Successfully delete item with matching etag", func(t *testing.T) { + etag := "1bdead4badc0ffee" + req := &state.DeleteRequest{ + ETag: &etag, + Key: "key", + } + + ss := StateStore{ + client: &mockedDynamoDB{ + DeleteItemFn: func(input *dynamodb.DeleteItemInput) (output *dynamodb.DeleteItemOutput, err error) { + assert.Equal(t, map[string]*dynamodb.AttributeValue{ + "key": { + S: aws.String(req.Key), + }, + }, input.Key) + assert.Equal(t, "etag = :etag", *input.ConditionExpression) + assert.Equal(t, &dynamodb.AttributeValue{ + S: aws.String("1bdead4badc0ffee"), + }, input.ExpressionAttributeValues[":etag"]) + + return nil, nil + }, + }, + } + err := ss.Delete(req) + assert.Nil(t, err) + }) + + t.Run("Unsuccessfully delete item with mismatched etag", func(t *testing.T) { + etag := "bogusetag" + req := &state.DeleteRequest{ + ETag: &etag, + Key: "key", + } + + ss := StateStore{ + client: &mockedDynamoDB{ + DeleteItemFn: func(input *dynamodb.DeleteItemInput) (output *dynamodb.DeleteItemOutput, err error) { + assert.Equal(t, map[string]*dynamodb.AttributeValue{ + "key": { + S: aws.String(req.Key), + }, + }, input.Key) + assert.Equal(t, "etag = :etag", *input.ConditionExpression) + assert.Equal(t, &dynamodb.AttributeValue{ + S: aws.String("bogusetag"), + }, input.ExpressionAttributeValues[":etag"]) + + var checkErr dynamodb.ConditionalCheckFailedException + return nil, &checkErr + }, + }, + } + err := ss.Delete(req) + assert.NotNil(t, err) + switch tagErr := err.(type) { + case *state.ETagError: + assert.Equal(t, tagErr.Kind(), state.ETagMismatch) + default: + assert.True(t, false) + } + }) + t.Run("Unsuccessfully delete item", func(t *testing.T) { ss := StateStore{ client: &mockedDynamoDB{ @@ -774,6 +955,9 @@ func TestBulkDelete(t *testing.T) { { Key: "key", }, + { + Key: "key", + }, } err := ss.BulkDelete(req) assert.NotNil(t, err)