diff --git a/pubsub/azure/eventhubs/eventhubs.go b/pubsub/azure/eventhubs/eventhubs.go index f391aefdf..512ff61a9 100644 --- a/pubsub/azure/eventhubs/eventhubs.go +++ b/pubsub/azure/eventhubs/eventhubs.go @@ -73,6 +73,9 @@ const ( sysPropIotHubEnqueuedTime = "iothub-enqueuedtime" sysPropMessageID = "message-id" + // Metadata field to ensure all Event Hub properties pass through + requireAllProperties = "requireAllProperties" + defaultMessageRetentionInDays = 1 defaultPartitionCount = 1 @@ -86,7 +89,7 @@ const ( maxPartitionCount = int32(1024) ) -func subscribeHandler(ctx context.Context, topic string, e *eventhub.Event, handler pubsub.Handler) error { +func subscribeHandler(ctx context.Context, topic string, getAllProperties bool, e *eventhub.Event, handler pubsub.Handler) error { res := pubsub.NewMessage{Data: e.Data, Topic: topic, Metadata: map[string]string{}} if e.SystemProperties.SequenceNumber != nil { res.Metadata[sysPropSequenceNumber] = strconv.FormatInt(*e.SystemProperties.SequenceNumber, 10) @@ -124,6 +127,16 @@ func subscribeHandler(ctx context.Context, topic string, e *eventhub.Event, hand if e.ID != "" { res.Metadata[sysPropMessageID] = e.ID } + // added properties if any ( includes application properties from iot-hub) + if getAllProperties { + if e.Properties != nil && len(e.Properties) > 0 { + for key, value := range e.Properties { + if str, ok := value.(string); ok { + res.Metadata[key] = str + } + } + } + } return handler(ctx, &res) } @@ -622,6 +635,14 @@ func (aeh *AzureEventHubs) Subscribe(subscribeCtx context.Context, req pubsub.Su return err } + getAllProperties := false + if req.Metadata[requireAllProperties] != "" { + getAllProperties, err = strconv.ParseBool(req.Metadata[requireAllProperties]) + if err != nil { + aeh.logger.Errorf("invalid value for metadata : %s . Error: %v.", requireAllProperties, err) + } + } + aeh.logger.Debugf("registering handler for topic %s", req.Topic) _, err = processor.RegisterHandler(subscribeCtx, func(_ context.Context, e *eventhub.Event) error { @@ -631,7 +652,7 @@ func (aeh *AzureEventHubs) Subscribe(subscribeCtx context.Context, req pubsub.Su retryerr := retry.NotifyRecover(func() error { aeh.logger.Debugf("Processing EventHubs event %s/%s", req.Topic, e.ID) - return subscribeHandler(subscribeCtx, req.Topic, e, handler) + return subscribeHandler(subscribeCtx, req.Topic, getAllProperties, e, handler) }, b, func(_ error, _ time.Duration) { aeh.logger.Warnf("Error processing EventHubs event: %s/%s. Retrying...", req.Topic, e.ID) }, func() { diff --git a/pubsub/azure/eventhubs/eventhubs_integration_test.go b/pubsub/azure/eventhubs/eventhubs_integration_test.go index 077e91ac9..285063345 100644 --- a/pubsub/azure/eventhubs/eventhubs_integration_test.go +++ b/pubsub/azure/eventhubs/eventhubs_integration_test.go @@ -46,6 +46,7 @@ const ( testStorageContainerName = "iothub-pubsub-integration-test" testTopic = "integration-test-topic" + applicationProperty = "applicationProperty" ) func createIotHubPubsubMetadata() pubsub.Metadata { @@ -86,8 +87,10 @@ func testReadIotHubEvents(t *testing.T) { } req := pubsub.SubscribeRequest{ - Topic: testTopic, // TODO: Handle Topic configuration after EventHubs pubsub rewrite #951 - Metadata: map[string]string{}, + Topic: testTopic, // TODO: Handle Topic configuration after EventHubs pubsub rewrite #951 + Metadata: map[string]string{ + "requireAllProperties": "true", + }, } err = eh.Subscribe(context.Background(), req, handler) assert.Nil(t, err) @@ -114,6 +117,9 @@ func testReadIotHubEvents(t *testing.T) { assert.Contains(t, r.Metadata, sysPropIotHubConnectionAuthMethod, "IoT device event missing: %s", sysPropIotHubConnectionAuthMethod) assert.Contains(t, r.Metadata, sysPropIotHubEnqueuedTime, "IoT device event missing: %s", sysPropIotHubEnqueuedTime) assert.Contains(t, r.Metadata, sysPropMessageID, "IoT device event missing: %s", sysPropMessageID) + + // Verify sent custom application property is received in IoT Hub device event metadata + assert.Contains(t, r.Metadata, applicationProperty, "IoT device event missing: %s", applicationProperty) } eh.Close() diff --git a/pubsub/jetstream/jetstream.go b/pubsub/jetstream/jetstream.go index bb886485b..3965cb0a5 100644 --- a/pubsub/jetstream/jetstream.go +++ b/pubsub/jetstream/jetstream.go @@ -68,7 +68,17 @@ func (js *jetstreamPubSub) Init(metadata pubsub.Metadata) error { } js.l.Debugf("Connected to nats at %s", js.meta.natsURL) - js.jsc, err = js.nc.JetStream() + jsOpts := []nats.JSOpt{} + + if js.meta.domain != "" { + jsOpts = append(jsOpts, nats.Domain(js.meta.domain)) + } + + if js.meta.apiPrefix != "" { + jsOpts = append(jsOpts, nats.APIPrefix(js.meta.apiPrefix)) + } + + js.jsc, err = js.nc.JetStream(jsOpts...) if err != nil { return err } diff --git a/pubsub/jetstream/metadata.go b/pubsub/jetstream/metadata.go index c7cf78288..c11408a40 100644 --- a/pubsub/jetstream/metadata.go +++ b/pubsub/jetstream/metadata.go @@ -51,6 +51,8 @@ type metadata struct { hearbeat time.Duration deliverPolicy nats.DeliverPolicy ackPolicy nats.AckPolicy + domain string + apiPrefix string } func parseMetadata(psm pubsub.Metadata) (metadata, error) { @@ -143,6 +145,13 @@ func parseMetadata(psm pubsub.Metadata) (metadata, error) { m.hearbeat = v } + if domain := psm.Properties["domain"]; domain != "" { + m.domain = domain + } + if apiPrefix := psm.Properties["apiPrefix"]; apiPrefix != "" { + m.apiPrefix = apiPrefix + } + deliverPolicy := psm.Properties["deliverPolicy"] switch deliverPolicy { case "all", "": diff --git a/pubsub/jetstream/metadata_test.go b/pubsub/jetstream/metadata_test.go index 65a2fab52..dedce4b86 100644 --- a/pubsub/jetstream/metadata_test.go +++ b/pubsub/jetstream/metadata_test.go @@ -50,6 +50,7 @@ func TestParseMetadata(t *testing.T) { "memoryStorage": "true", "rateLimit": "20000", "hearbeat": "1s", + "domain": "hub", }, }}, want: metadata{ @@ -70,6 +71,7 @@ func TestParseMetadata(t *testing.T) { hearbeat: time.Second * 1, deliverPolicy: nats.DeliverAllPolicy, ackPolicy: nats.AckExplicitPolicy, + domain: "hub", }, expectErr: false, }, @@ -95,6 +97,7 @@ func TestParseMetadata(t *testing.T) { "deliverPolicy": "sequence", "startSequence": "5", "ackPolicy": "all", + "apiPrefix": "HUB", }, }}, want: metadata{ @@ -116,6 +119,7 @@ func TestParseMetadata(t *testing.T) { token: "myToken", deliverPolicy: nats.DeliverByStartSequencePolicy, ackPolicy: nats.AckAllPolicy, + apiPrefix: "HUB", }, expectErr: false, }, diff --git a/pubsub/mqtt/mqtt.go b/pubsub/mqtt/mqtt.go index 4ef85293f..2a480a811 100644 --- a/pubsub/mqtt/mqtt.go +++ b/pubsub/mqtt/mqtt.go @@ -111,7 +111,16 @@ func (m *mqttPubSub) Publish(_ context.Context, req *pubsub.PublishRequest) erro // m.logger.Debugf("mqtt publishing topic %s with data: %v", req.Topic, req.Data) m.logger.Debugf("mqtt publishing topic %s", req.Topic) - token := m.producer.Publish(req.Topic, m.metadata.qos, m.metadata.retain, req.Data) + retain := m.metadata.retain + if val, ok := req.Metadata[mqttRetain]; ok && val != "" { + var err error + retain, err = strconv.ParseBool(val) + if err != nil { + return fmt.Errorf("mqtt invalid retain %s, %s", val, err) + } + } + + token := m.producer.Publish(req.Topic, m.metadata.qos, retain, req.Data) t := time.NewTimer(defaultWait) defer func() { if !t.Stop() { diff --git a/pubsub/mqtt/mqtt_test.go b/pubsub/mqtt/mqtt_test.go index 4d2225132..b0a611a41 100644 --- a/pubsub/mqtt/mqtt_test.go +++ b/pubsub/mqtt/mqtt_test.go @@ -14,11 +14,20 @@ limitations under the License. package mqtt import ( + "context" "crypto/x509" "encoding/pem" "errors" + "fmt" + "math" + "math/rand" + "reflect" "regexp" + "sync" "testing" + "time" + + mqtt "github.com/eclipse/paho.mqtt.golang" "github.com/stretchr/testify/assert" @@ -27,6 +36,176 @@ import ( "github.com/dapr/kit/logger" ) +type mqttMessage struct { + data []byte + retained bool + topic string + qos byte +} + +var _ mqtt.Message = (*mqttMessage)(nil) + +func (m mqttMessage) Duplicate() bool { + return false +} + +func (m mqttMessage) Qos() byte { + return m.qos +} + +func (m mqttMessage) Retained() bool { + return m.retained +} + +func (m mqttMessage) Topic() string { + return m.topic +} + +func (m mqttMessage) MessageID() uint16 { + return uint16(rand.Intn(math.MaxUint16 + 1)) //nolint:gosec +} + +func (m mqttMessage) Payload() []byte { + return m.data +} + +func (m mqttMessage) Ack() { + return +} + +type mockedMQTTToken struct { + m sync.RWMutex + complete chan struct{} + err error +} + +var _ mqtt.Token = (*mockedMQTTToken)(nil) + +func (t *mockedMQTTToken) Wait() bool { + <-t.complete + return true +} + +func (t *mockedMQTTToken) WaitTimeout(d time.Duration) bool { + timer := time.NewTimer(d) + select { + case <-t.complete: + if !timer.Stop() { + <-timer.C + } + return true + case <-timer.C: + } + + return false +} + +func (t *mockedMQTTToken) Done() <-chan struct{} { + return t.complete +} + +func (t *mockedMQTTToken) flowComplete() { + select { + case <-t.complete: + default: + close(t.complete) + } +} + +func (t *mockedMQTTToken) Error() error { + t.m.RLock() + defer t.m.RUnlock() + return t.err +} + +type mockedMQTTClient struct { + msgCh chan mqttMessage +} + +var _ mqtt.Client = (*mockedMQTTClient)(nil) + +func newMockedMQTTClient(ch chan mqttMessage) *mockedMQTTClient { + return &mockedMQTTClient{ + msgCh: ch, + } +} + +func (m mockedMQTTClient) IsConnected() bool { + return true +} + +func (m mockedMQTTClient) IsConnectionOpen() bool { + return true +} + +func (m mockedMQTTClient) Connect() mqtt.Token { + token := &mockedMQTTToken{complete: make(chan struct{})} + token.flowComplete() + + return token +} + +func (m mockedMQTTClient) Disconnect(quiesce uint) { + return +} + +func (m mockedMQTTClient) Publish(topic string, qos byte, retained bool, payload interface{}) mqtt.Token { + token := &mockedMQTTToken{complete: make(chan struct{})} + + msg := mqttMessage{ + data: payload.([]byte), + retained: retained, + topic: topic, + qos: qos, + } + m.msgCh <- msg + + token.flowComplete() + + return token +} + +func (m mockedMQTTClient) Subscribe(topic string, qos byte, callback mqtt.MessageHandler) mqtt.Token { + token := &mockedMQTTToken{complete: make(chan struct{})} + token.flowComplete() + + go func() { + for msg := range m.msgCh { + callback(m, msg) + } + }() + + return token +} + +func (m mockedMQTTClient) SubscribeMultiple(filters map[string]byte, callback mqtt.MessageHandler) mqtt.Token { + token := &mockedMQTTToken{complete: make(chan struct{})} + token.flowComplete() + + go func() { + for msg := range m.msgCh { + callback(m, msg) + } + }() + + return token +} + +func (m mockedMQTTClient) Unsubscribe(topics ...string) mqtt.Token { + token := &mockedMQTTToken{complete: make(chan struct{})} + token.flowComplete() + + return token +} + +func (m mockedMQTTClient) AddRoute(topic string, callback mqtt.MessageHandler) { + return +} + +func (m mockedMQTTClient) OptionsReader() mqtt.ClientOptionsReader { + return mqtt.ClientOptionsReader{} +} + func getFakeProperties() map[string]string { return map[string]string{ "consumerID": "client", @@ -456,3 +635,96 @@ func Test_buildRegexForTopic(t *testing.T) { }) } } + +func Test_mqttPubSub_Publish(t *testing.T) { + type fields struct { + logger logger.Logger + metadata *metadata + ctx context.Context + } + type args struct { + req *pubsub.PublishRequest + } + tests := []struct { + name string + fields fields + args args + wantErr assert.ErrorAssertionFunc + wantedMsg mqttMessage + }{ + { + name: "publish request does not contain retain metadata", + fields: fields{ + logger: logger.NewLogger("mqtt-test"), + ctx: context.Background(), + metadata: &metadata{ + retain: true, + }, + }, + args: args{ + req: &pubsub.PublishRequest{ + Data: []byte("test"), + PubsubName: "mqtt", + Metadata: map[string]string{}, + Topic: "test", + ContentType: nil, + }, + }, + wantErr: assert.NoError, + wantedMsg: mqttMessage{ + data: []byte("test"), + retained: true, + topic: "test", + qos: 0, + }, + }, + { + name: "publish request contains retain metadata", + fields: fields{ + logger: logger.NewLogger("mqtt-test"), + ctx: context.Background(), + metadata: &metadata{ + retain: true, + }, + }, + args: args{ + req: &pubsub.PublishRequest{ + Data: []byte("test"), + PubsubName: "mqtt", + Metadata: map[string]string{"retain": "false"}, + Topic: "test", + ContentType: nil, + }, + }, + wantErr: assert.NoError, + wantedMsg: mqttMessage{ + data: []byte("test"), + retained: false, + topic: "test", + qos: 0, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + msgCh := make(chan mqttMessage, 1) + + m := &mqttPubSub{ + producer: newMockedMQTTClient(msgCh), + logger: tt.fields.logger, + ctx: tt.fields.ctx, + metadata: tt.fields.metadata, + } + + ctx := context.Background() + tt.wantErr(t, m.Publish(ctx, tt.args.req), fmt.Sprintf("Publish(%v, %v)", ctx, tt.args.req)) + close(msgCh) + + for msg := range msgCh { + if !reflect.DeepEqual(msg, tt.wantedMsg) { + t.Errorf("received different message than expected, got = %v, want %v", m, tt.wantedMsg) + } + } + }) + } +} diff --git a/state/aws/dynamodb/dynamodb.go b/state/aws/dynamodb/dynamodb.go index 9bf829847..ca1521d79 100644 --- a/state/aws/dynamodb/dynamodb.go +++ b/state/aws/dynamodb/dynamodb.go @@ -51,6 +51,10 @@ type dynamoDBMetadata struct { TTLAttributeName string `json:"ttlAttributeName"` } +const ( + metadataPartitionKey = "partitionKey" +) + // NewDynamoDBStateStore returns a new dynamoDB state store. func NewDynamoDBStateStore(_ logger.Logger) state.Store { return &StateStore{} @@ -87,7 +91,7 @@ func (d *StateStore) Get(ctx context.Context, req *state.GetRequest) (*state.Get TableName: aws.String(d.table), Key: map[string]*dynamodb.AttributeValue{ "key": { - S: aws.String(req.Key), + S: aws.String(populatePartitionMetadata(req.Key, req.Metadata)), }, }, } @@ -224,7 +228,7 @@ func (d *StateStore) Delete(ctx context.Context, req *state.DeleteRequest) error input := &dynamodb.DeleteItemInput{ Key: map[string]*dynamodb.AttributeValue{ "key": { - S: aws.String(req.Key), + S: aws.String(populatePartitionMetadata(req.Key, req.Metadata)), }, }, TableName: aws.String(d.table), @@ -268,7 +272,7 @@ func (d *StateStore) BulkDelete(ctx context.Context, req []state.DeleteRequest) DeleteRequest: &dynamodb.DeleteRequest{ Key: map[string]*dynamodb.AttributeValue{ "key": { - S: aws.String(r.Key), + S: aws.String(populatePartitionMetadata(r.Key, r.Metadata)), }, }, }, @@ -314,9 +318,10 @@ func (d *StateStore) getClient(metadata *dynamoDBMetadata) (*dynamodb.DynamoDB, // getItemFromReq converts a dapr state.SetRequest into an dynamodb item func (d *StateStore) getItemFromReq(req *state.SetRequest) (map[string]*dynamodb.AttributeValue, error) { + partitionKey := populatePartitionMetadata(req.Key, req.Metadata) value, err := d.marshalToString(req.Value) if err != nil { - return nil, fmt.Errorf("dynamodb error: failed to set key %s: %s", req.Key, err) + return nil, fmt.Errorf("dynamodb error: failed to set key %s: %s", partitionKey, err) } ttl, err := d.parseTTL(req) @@ -328,9 +333,10 @@ func (d *StateStore) getItemFromReq(req *state.SetRequest) (map[string]*dynamodb if err != nil { return nil, fmt.Errorf("dynamodb error: failed to generate etag: %w", err) } + item := map[string]*dynamodb.AttributeValue{ "key": { - S: aws.String(req.Key), + S: aws.String(partitionKey), }, "value": { S: aws.String(value), @@ -385,3 +391,13 @@ func (d *StateStore) parseTTL(req *state.SetRequest) (*int64, error) { return nil, nil } + +// This is a helper to return the partition key to use. If if metadata["partitionkey"] is present, +// use that, otherwise use what's in "key". +func populatePartitionMetadata(key string, requestMetadata map[string]string) string { + if val, found := requestMetadata[metadataPartitionKey]; found { + return val + } + + return key +} diff --git a/state/aws/dynamodb/dynamodb_test.go b/state/aws/dynamodb/dynamodb_test.go index 815d454e0..426c869d0 100644 --- a/state/aws/dynamodb/dynamodb_test.go +++ b/state/aws/dynamodb/dynamodb_test.go @@ -46,6 +46,11 @@ type DynamoDBItem struct { TestAttributeName int64 `json:"testAttributeName"` } +const ( + tableName = "table_name" + pkey = "partitionKey" +) + func (m *mockedDynamoDB) GetItemWithContext(ctx context.Context, input *dynamodb.GetItemInput, op ...request.Option) (*dynamodb.GetItemOutput, error) { return m.GetItemWithContextFn(ctx, input, op...) } @@ -268,6 +273,45 @@ func TestGet(t *testing.T) { assert.Nil(t, err) assert.Empty(t, out.Data) }) + t.Run("Successfully retrieve item with metadata partition key", func(t *testing.T) { + ss := StateStore{ + client: &mockedDynamoDB{ + GetItemWithContextFn: func(ctx context.Context, input *dynamodb.GetItemInput, op ...request.Option) (output *dynamodb.GetItemOutput, err error) { + if *input.Key["key"].S != pkey { + return &dynamodb.GetItemOutput{ + Item: map[string]*dynamodb.AttributeValue{}, + }, nil + } + return &dynamodb.GetItemOutput{ + Item: map[string]*dynamodb.AttributeValue{ + "key": { + S: input.Key["key"].S, + }, + "value": { + S: aws.String("some value"), + }, + "etag": { + S: aws.String("1bdead4badc0ffee"), + }, + }, + }, nil + }, + }, + } + req := &state.GetRequest{ + Key: "someKey", + Metadata: map[string]string{ + metadataPartitionKey: pkey, + }, + Options: state.GetStateOption{ + Consistency: "strong", + }, + } + out, err := ss.Get(context.Background(), req) + assert.Nil(t, err) + assert.Equal(t, []byte("some value"), out.Data) + assert.Equal(t, "1bdead4badc0ffee", *out.ETag) + }) } func TestSet(t *testing.T) { @@ -619,6 +663,40 @@ func TestSet(t *testing.T) { assert.NotNil(t, err) assert.Equal(t, "dynamodb error: failed to parse ttlInSeconds: strconv.ParseInt: parsing \"invalidvalue\": invalid syntax", err.Error()) }) + t.Run("Successfully set item with metadata partition key", func(t *testing.T) { + ss := StateStore{ + client: &mockedDynamoDB{ + PutItemWithContextFn: func(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (output *dynamodb.PutItemOutput, err error) { + assert.Equal(t, dynamodb.AttributeValue{ + S: aws.String(pkey), + }, *input.Item["key"]) + assert.Equal(t, dynamodb.AttributeValue{ + S: aws.String(`{"Value":"value"}`), + }, *input.Item["value"]) + assert.Equal(t, len(input.Item), 3) + + return &dynamodb.PutItemOutput{ + Attributes: map[string]*dynamodb.AttributeValue{ + "key": { + S: aws.String("value"), + }, + }, + }, nil + }, + }, + } + req := &state.SetRequest{ + Key: "key", + Metadata: map[string]string{ + metadataPartitionKey: pkey, + }, + Value: value{ + Value: "value", + }, + } + err := ss.Set(context.Background(), req) + assert.Nil(t, err) + }) } func TestBulkSet(t *testing.T) { @@ -627,7 +705,6 @@ func TestBulkSet(t *testing.T) { } t.Run("Successfully set items", func(t *testing.T) { - tableName := "table_name" ss := StateStore{ client: &mockedDynamoDB{ BatchWriteItemWithContextFn: func(ctx context.Context, input *dynamodb.BatchWriteItemInput, op ...request.Option) (output *dynamodb.BatchWriteItemOutput, err error) { @@ -694,7 +771,6 @@ func TestBulkSet(t *testing.T) { assert.Nil(t, err) }) t.Run("Successfully set items with ttl = -1", func(t *testing.T) { - tableName := "table_name" ss := StateStore{ client: &mockedDynamoDB{ BatchWriteItemWithContextFn: func(ctx context.Context, input *dynamodb.BatchWriteItemInput, op ...request.Option) (output *dynamodb.BatchWriteItemOutput, err error) { @@ -767,7 +843,6 @@ func TestBulkSet(t *testing.T) { assert.Nil(t, err) }) t.Run("Successfully set items with ttl", func(t *testing.T) { - tableName := "table_name" ss := StateStore{ client: &mockedDynamoDB{ BatchWriteItemWithContextFn: func(ctx context.Context, input *dynamodb.BatchWriteItemInput, op ...request.Option) (output *dynamodb.BatchWriteItemOutput, err error) { @@ -866,6 +941,78 @@ func TestBulkSet(t *testing.T) { err := ss.BulkSet(context.Background(), req) assert.NotNil(t, err) }) + t.Run("Successfully set items with metadata partition key", func(t *testing.T) { + ss := StateStore{ + client: &mockedDynamoDB{ + BatchWriteItemWithContextFn: func(ctx context.Context, input *dynamodb.BatchWriteItemInput, op ...request.Option) (output *dynamodb.BatchWriteItemOutput, err error) { + expected := map[string][]*dynamodb.WriteRequest{} + expected[tableName] = []*dynamodb.WriteRequest{ + { + PutRequest: &dynamodb.PutRequest{ + Item: map[string]*dynamodb.AttributeValue{ + "key": { + S: aws.String(pkey), + }, + "value": { + S: aws.String(`{"Value":"value1"}`), + }, + }, + }, + }, + { + PutRequest: &dynamodb.PutRequest{ + Item: map[string]*dynamodb.AttributeValue{ + "key": { + S: aws.String(pkey), + }, + "value": { + S: aws.String(`{"Value":"value2"}`), + }, + }, + }, + }, + } + + for tbl := range expected { + for reqNum := range expected[tbl] { + expectedItem := expected[tbl][reqNum].PutRequest.Item + inputItem := input.RequestItems[tbl][reqNum].PutRequest.Item + + assert.Equal(t, expectedItem["key"], inputItem["key"]) + assert.Equal(t, expectedItem["value"], inputItem["value"]) + } + } + + return &dynamodb.BatchWriteItemOutput{ + UnprocessedItems: map[string][]*dynamodb.WriteRequest{}, + }, nil + }, + }, + table: tableName, + } + req := []state.SetRequest{ + { + Key: "key1", + Metadata: map[string]string{ + metadataPartitionKey: pkey, + }, + Value: value{ + Value: "value1", + }, + }, + { + Key: "key2", + Metadata: map[string]string{ + metadataPartitionKey: pkey, + }, + Value: value{ + Value: "value2", + }, + }, + } + err := ss.BulkSet(context.Background(), req) + assert.Nil(t, err) + }) } func TestDelete(t *testing.T) { @@ -968,11 +1115,35 @@ func TestDelete(t *testing.T) { err := ss.Delete(context.Background(), req) assert.NotNil(t, err) }) + + t.Run("Successfully delete item with metadata partition key", func(t *testing.T) { + req := &state.DeleteRequest{ + Key: "key", + Metadata: map[string]string{ + metadataPartitionKey: pkey, + }, + } + + ss := StateStore{ + client: &mockedDynamoDB{ + DeleteItemWithContextFn: func(ctx context.Context, input *dynamodb.DeleteItemInput, op ...request.Option) (output *dynamodb.DeleteItemOutput, err error) { + assert.Equal(t, map[string]*dynamodb.AttributeValue{ + "key": { + S: aws.String(pkey), + }, + }, input.Key) + + return nil, nil + }, + }, + } + err := ss.Delete(context.Background(), req) + assert.Nil(t, err) + }) } func TestBulkDelete(t *testing.T) { t.Run("Successfully delete items", func(t *testing.T) { - tableName := "table_name" ss := StateStore{ client: &mockedDynamoDB{ BatchWriteItemWithContextFn: func(ctx context.Context, input *dynamodb.BatchWriteItemInput, op ...request.Option) (output *dynamodb.BatchWriteItemOutput, err error) { @@ -1036,4 +1207,55 @@ func TestBulkDelete(t *testing.T) { err := ss.BulkDelete(context.Background(), req) assert.NotNil(t, err) }) + t.Run("Successfully delete items with metadata partition key", func(t *testing.T) { + ss := StateStore{ + client: &mockedDynamoDB{ + BatchWriteItemWithContextFn: func(ctx context.Context, input *dynamodb.BatchWriteItemInput, op ...request.Option) (output *dynamodb.BatchWriteItemOutput, err error) { + expected := map[string][]*dynamodb.WriteRequest{} + expected[tableName] = []*dynamodb.WriteRequest{ + { + DeleteRequest: &dynamodb.DeleteRequest{ + Key: map[string]*dynamodb.AttributeValue{ + "key": { + S: aws.String(pkey), + }, + }, + }, + }, + { + DeleteRequest: &dynamodb.DeleteRequest{ + Key: map[string]*dynamodb.AttributeValue{ + "key": { + S: aws.String(pkey), + }, + }, + }, + }, + } + assert.Equal(t, expected, input.RequestItems) + + return &dynamodb.BatchWriteItemOutput{ + UnprocessedItems: map[string][]*dynamodb.WriteRequest{}, + }, nil + }, + }, + table: tableName, + } + req := []state.DeleteRequest{ + { + Key: "key1", + Metadata: map[string]string{ + metadataPartitionKey: pkey, + }, + }, + { + Key: "key2", + Metadata: map[string]string{ + metadataPartitionKey: pkey, + }, + }, + } + err := ss.BulkDelete(context.Background(), req) + assert.Nil(t, err) + }) } diff --git a/tests/certification/pubsub/azure/eventhubs/send-iot-device-events.sh b/tests/certification/pubsub/azure/eventhubs/send-iot-device-events.sh index 432266e01..13662b2fc 100755 --- a/tests/certification/pubsub/azure/eventhubs/send-iot-device-events.sh +++ b/tests/certification/pubsub/azure/eventhubs/send-iot-device-events.sh @@ -41,4 +41,4 @@ fi # Send the test IoT device messages to the IoT Hub.`testmessageForEventHubCertificationTest` is being asserted in the certification test # TODO : read messageCount and data as an argument -az iot device simulate -n ${AzureIotHubName} -d ${IOT_HUB_TEST_DEVICE_NAME} --data 'testmessageForEventHubCertificationTest' --msg-count 10 --msg-interval 1 --protocol http --properties "iothub-userid=dapr-user-id;iothub-messageid=dapr-message-id" +az iot device simulate -n ${AzureIotHubName} -d ${IOT_HUB_TEST_DEVICE_NAME} --data 'testmessageForEventHubCertificationTest' --msg-count 10 --msg-interval 1 --protocol http --properties "iothub-userid=dapr-user-id;iothub-messageid=dapr-message-id;applicationProperty=custom-value" diff --git a/tests/scripts/send-iot-device-events.sh b/tests/scripts/send-iot-device-events.sh index 50ba53491..4b500455d 100644 --- a/tests/scripts/send-iot-device-events.sh +++ b/tests/scripts/send-iot-device-events.sh @@ -38,4 +38,4 @@ if [[ -z "$(az iot hub device-identity show -n ${IOT_HUB_NAME} -d ${IOT_HUB_TEST fi # Send the test IoT device messages to the IoT Hub -az iot device simulate -n ${IOT_HUB_NAME} -d ${IOT_HUB_TEST_DEVICE_NAME} --data '{ "data": "Integration test message" }' --msg-count 2 --msg-interval 1 --protocol http --properties "iothub-userid=dapr-user-id;iothub-messageid=dapr-message-id" +az iot device simulate -n ${IOT_HUB_NAME} -d ${IOT_HUB_TEST_DEVICE_NAME} --data '{ "data": "Integration test message" }' --msg-count 2 --msg-interval 1 --protocol http --properties "iothub-userid=dapr-user-id;iothub-messageid=dapr-message-id;applicationProperty=custom-value"