Refactor Metadata Parsing of all PubSub Components (#2759)

Signed-off-by: Bernd Verst <github@bernd.dev>
This commit is contained in:
Bernd Verst 2023-04-09 13:21:56 -05:00 committed by GitHub
parent 4f94da95cf
commit 2b89d78a2d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
48 changed files with 1197 additions and 1730 deletions

View File

@ -118,6 +118,11 @@ func GenerateMetadataAnalyzer(contribRoot string, componentFolders []string, out
if methodFinderErr == nil {
methodFound = true
}
case "pubsub":
method, methodFinderErr = getConstructorMethod("pubsub.PubSub", parsedFile)
if methodFinderErr == nil {
methodFound = true
}
}
if methodFound {

View File

@ -205,6 +205,6 @@ func (a *AzureServiceBusQueues) GetComponentMetadata() map[string]string {
metadataStruct := impl.Metadata{}
metadataInfo := map[string]string{}
contribMetadata.GetMetadataInfoFromStructType(reflect.TypeOf(metadataStruct), &metadataInfo, contribMetadata.BindingType)
delete(metadataInfo, "ConsumerID") // only applies to topics, not queues
delete(metadataInfo, "consumerID") // only applies to topics, not queues
return metadataInfo
}

View File

@ -16,12 +16,10 @@ package servicebus
import (
"errors"
"fmt"
"strconv"
"time"
sbadmin "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/admin"
"github.com/dapr/components-contrib/internal/utils"
mdutils "github.com/dapr/components-contrib/metadata"
"github.com/dapr/kit/logger"
"github.com/dapr/kit/ptr"
@ -31,27 +29,27 @@ import (
// Note: AzureAD-related keys are handled separately.
type Metadata struct {
/** For bindings and pubsubs **/
ConnectionString string `json:"connectionString"`
ConsumerID string `json:"consumerID"` // Only topics
TimeoutInSec int `json:"timeoutInSec"`
HandlerTimeoutInSec int `json:"handlerTimeoutInSec"`
LockRenewalInSec int `json:"lockRenewalInSec"`
MaxActiveMessages int `json:"maxActiveMessages"`
MaxConnectionRecoveryInSec int `json:"maxConnectionRecoveryInSec"`
MinConnectionRecoveryInSec int `json:"minConnectionRecoveryInSec"`
DisableEntityManagement bool `json:"disableEntityManagement"`
MaxRetriableErrorsPerSec int `json:"maxRetriableErrorsPerSec"`
MaxDeliveryCount *int32 `json:"maxDeliveryCount"` // Only used during subscription creation - default is set by the server (10)
LockDurationInSec *int `json:"lockDurationInSec"` // Only used during subscription creation - default is set by the server (60s)
DefaultMessageTimeToLiveInSec *int `json:"defaultMessageTimeToLiveInSec"` // Only used during subscription creation - default is set by the server (depends on the tier)
AutoDeleteOnIdleInSec *int `json:"autoDeleteOnIdleInSec"` // Only used during subscription creation - default is set by the server (disabled)
MaxConcurrentHandlers int `json:"maxConcurrentHandlers"`
PublishMaxRetries int `json:"publishMaxRetries"`
PublishInitialRetryIntervalInMs int `json:"publishInitialRetryIntervalInMs"`
NamespaceName string `json:"namespaceName"` // Only for Azure AD
ConnectionString string `mapstructure:"connectionString"`
ConsumerID string `mapstructure:"consumerID"` // Only topics
TimeoutInSec int `mapstructure:"timeoutInSec"`
HandlerTimeoutInSec int `mapstructure:"handlerTimeoutInSec"`
LockRenewalInSec int `mapstructure:"lockRenewalInSec"`
MaxActiveMessages int `mapstructure:"maxActiveMessages"`
MaxConnectionRecoveryInSec int `mapstructure:"maxConnectionRecoveryInSec"`
MinConnectionRecoveryInSec int `mapstructure:"minConnectionRecoveryInSec"`
DisableEntityManagement bool `mapstructure:"disableEntityManagement"`
MaxRetriableErrorsPerSec int `mapstructure:"maxRetriableErrorsPerSec"`
MaxDeliveryCount *int32 `mapstructure:"maxDeliveryCount"` // Only used during subscription creation - default is set by the server (10)
LockDurationInSec *int `mapstructure:"lockDurationInSec"` // Only used during subscription creation - default is set by the server (60s)
DefaultMessageTimeToLiveInSec *int `mapstructure:"defaultMessageTimeToLiveInSec"` // Only used during subscription creation - default is set by the server (depends on the tier)
AutoDeleteOnIdleInSec *int `mapstructure:"autoDeleteOnIdleInSec"` // Only used during subscription creation - default is set by the server (disabled)
MaxConcurrentHandlers int `mapstructure:"maxConcurrentHandlers"`
PublishMaxRetries int `mapstructure:"publishMaxRetries"`
PublishInitialRetryIntervalInMs int `mapstructure:"publishInitialRetryIntervalInMs"`
NamespaceName string `mapstructure:"namespaceName"` // Only for Azure AD
/** For bindings only **/
QueueName string `json:"queueName"` // Only queues
QueueName string `mapstructure:"queueName" only:"binding"` // Only queues
}
// Keys.
@ -120,198 +118,89 @@ const (
// ParseMetadata parses metadata keys that are common to all Service Bus components
func ParseMetadata(md map[string]string, logger logger.Logger, mode byte) (m *Metadata, err error) {
m = &Metadata{}
m = &Metadata{
TimeoutInSec: defaultTimeoutInSec,
LockRenewalInSec: defaultLockRenewalInSec,
MaxActiveMessages: defaultMaxActiveMessagesPubSub,
MaxConnectionRecoveryInSec: defaultMaxConnectionRecoveryInSec,
MinConnectionRecoveryInSec: defaultMinConnectionRecoveryInSec,
DisableEntityManagement: defaultDisableEntityManagement,
MaxRetriableErrorsPerSec: defaultMaxRetriableErrorsPerSec,
MaxConcurrentHandlers: defaultMaxConcurrentHandlersPubSub,
PublishMaxRetries: defaultPublishMaxRetries,
PublishInitialRetryIntervalInMs: defaultPublishInitialRetryIntervalInMs,
}
if (mode & MetadataModeBinding) != 0 {
m.HandlerTimeoutInSec = defaultHandlerTimeoutInSecBinding
m.MaxActiveMessages = defaultMaxActiveMessagesBinding
m.MaxConcurrentHandlers = defaultMaxConcurrentHandlersBinding
} else {
m.HandlerTimeoutInSec = defaultHandlerTimeoutInSecPubSub
m.MaxActiveMessages = defaultMaxActiveMessagesPubSub
m.MaxConcurrentHandlers = defaultMaxConcurrentHandlersPubSub
}
// upgrade deprecated metadata keys
if val, ok := md["publishInitialRetryInternalInMs"]; ok && val != "" {
// TODO: Remove in a future Dapr release
logger.Warn("Found deprecated metadata property 'publishInitialRetryInternalInMs'; please use 'publishInitialRetryIntervalInMs'")
md["publishInitialRetryIntervalInMs"] = val
delete(md, "publishInitialRetryInternalInMs")
}
mdErr := mdutils.DecodeMetadata(md, &m)
if mdErr != nil {
return m, mdErr
}
/* Required configuration settings - no defaults. */
if val, ok := md[keyConnectionString]; ok && val != "" {
m.ConnectionString = val
if m.ConnectionString != "" {
// The connection string and the namespace cannot both be present.
if namespace, present := md[keyNamespaceName]; present && namespace != "" {
if m.NamespaceName != "" {
return m, errors.New("connectionString and namespaceName cannot both be specified")
}
} else if val, ok := md[keyNamespaceName]; ok && val != "" {
m.NamespaceName = val
} else {
} else if m.NamespaceName == "" {
return m, errors.New("either one of connection string or namespace name are required")
}
if (mode & MetadataModeTopics) != 0 {
if val, ok := md[keyConsumerID]; ok && val != "" {
m.ConsumerID = val
} else {
if m.ConsumerID == "" {
return m, errors.New("missing consumerID")
}
}
if (mode&MetadataModeBinding) != 0 && (mode&MetadataModeTopics) == 0 {
if val, ok := md[keyQueueName]; ok && val != "" {
m.QueueName = val
} else {
if m.QueueName == "" {
return m, errors.New("missing queueName")
}
}
/* Optional configuration settings - defaults will be set by the client. */
m.TimeoutInSec = defaultTimeoutInSec
if val, ok := md[keyTimeoutInSec]; ok && val != "" {
m.TimeoutInSec, err = strconv.Atoi(val)
if err != nil {
return m, fmt.Errorf("invalid timeoutInSec %s: %s", val, err)
}
if m.MaxActiveMessages < 1 {
err = errors.New("must be 1 or greater")
return m, err
}
m.DisableEntityManagement = defaultDisableEntityManagement
if val, ok := md[keyDisableEntityManagement]; ok && val != "" {
m.DisableEntityManagement = utils.IsTruthy(val)
}
if (mode & MetadataModeBinding) != 0 {
m.HandlerTimeoutInSec = defaultHandlerTimeoutInSecBinding
} else {
m.HandlerTimeoutInSec = defaultHandlerTimeoutInSecPubSub
}
if val, ok := md[keyHandlerTimeoutInSec]; ok && val != "" {
m.HandlerTimeoutInSec, err = strconv.Atoi(val)
if err != nil {
return m, fmt.Errorf("invalid handlerTimeoutInSec %s: %s", val, err)
}
}
m.LockRenewalInSec = defaultLockRenewalInSec
if val, ok := md[keyLockRenewalInSec]; ok && val != "" {
m.LockRenewalInSec, err = strconv.Atoi(val)
if err != nil {
return m, fmt.Errorf("invalid lockRenewalInSec %s: %s", val, err)
}
}
if (mode & MetadataModeBinding) != 0 {
m.MaxActiveMessages = defaultMaxActiveMessagesBinding
} else {
m.MaxActiveMessages = defaultMaxActiveMessagesPubSub
}
if val, ok := md[keyMaxActiveMessages]; ok && val != "" {
m.MaxActiveMessages, err = strconv.Atoi(val)
if err == nil && m.MaxActiveMessages < 1 {
err = errors.New("must be 1 or greater")
}
if err != nil {
return m, fmt.Errorf("invalid maxActiveMessages %s: %s", val, err)
}
}
m.MaxRetriableErrorsPerSec = defaultMaxRetriableErrorsPerSec
if val, ok := md[keyMaxRetriableErrorsPerSec]; ok && val != "" {
m.MaxRetriableErrorsPerSec, err = strconv.Atoi(val)
if err == nil && m.MaxRetriableErrorsPerSec < 0 {
err = errors.New("must not be negative")
}
if err != nil {
return m, fmt.Errorf("invalid maxRetriableErrorsPerSec %s: %s", val, err)
}
}
m.MinConnectionRecoveryInSec = defaultMinConnectionRecoveryInSec
if val, ok := md[keyMinConnectionRecoveryInSec]; ok && val != "" {
m.MinConnectionRecoveryInSec, err = strconv.Atoi(val)
if err != nil {
return m, fmt.Errorf("invalid minConnectionRecoveryInSec %s: %s", val, err)
}
}
m.MaxConnectionRecoveryInSec = defaultMaxConnectionRecoveryInSec
if val, ok := md[keyMaxConnectionRecoveryInSec]; ok && val != "" {
m.MaxConnectionRecoveryInSec, err = strconv.Atoi(val)
if err != nil {
return m, fmt.Errorf("invalid maxConnectionRecoveryInSec %s: %s", val, err)
}
}
if (mode & MetadataModeBinding) != 0 {
m.MaxConcurrentHandlers = defaultMaxConcurrentHandlersBinding
} else {
m.MaxConcurrentHandlers = defaultMaxConcurrentHandlersPubSub
}
if val, ok := md[keyMaxConcurrentHandlers]; ok && val != "" {
m.MaxConcurrentHandlers, err = strconv.Atoi(val)
if err != nil {
return m, fmt.Errorf("invalid maxConcurrentHandlers %s: %s", val, err)
}
}
m.PublishMaxRetries = defaultPublishMaxRetries
if val, ok := md[keyPublishMaxRetries]; ok && val != "" {
m.PublishMaxRetries, err = strconv.Atoi(val)
if err != nil {
return m, fmt.Errorf("invalid publishMaxRetries %s: %s", val, err)
}
}
// This metadata property has an alias "publishInitialRetryInternalInMs" because of a typo in a previous version of Dapr
m.PublishInitialRetryIntervalInMs = defaultPublishInitialRetryIntervalInMs
if val, ok := md[keyPublishInitialRetryIntervalInMs]; ok && val != "" {
m.PublishInitialRetryIntervalInMs, err = strconv.Atoi(val)
if err != nil {
return m, fmt.Errorf("invalid publishInitialRetryIntervalInMs %s: %s", val, err)
}
} else if val, ok := md["publishInitialRetryInternalInMs"]; ok && val != "" {
// TODO: Remove in a future Dapr release
logger.Warn("Found deprecated metadata property 'publishInitialRetryInternalInMs'; please use 'publishInitialRetryIntervalInMs'")
m.PublishInitialRetryIntervalInMs, err = strconv.Atoi(val)
if err != nil {
return m, fmt.Errorf("invalid publishInitialRetryInternalInMs %s: %s", val, err)
}
if m.MaxRetriableErrorsPerSec < 0 {
err = errors.New("must not be negative")
return m, err
}
/* Nullable configuration settings - defaults will be set by the server. */
if val, ok := md[keyMaxDeliveryCount]; ok && val != "" {
var valAsInt int64
valAsInt, err = strconv.ParseInt(val, 10, 32)
if err != nil {
return m, fmt.Errorf("invalid maxDeliveryCount %s: %s", val, err)
if m.DefaultMessageTimeToLiveInSec == nil {
duration, found, ttlErr := mdutils.TryGetTTL(md)
if ttlErr != nil {
return m, fmt.Errorf("invalid %s %s: %s", mdutils.TTLMetadataKey, duration, ttlErr)
}
if found {
m.DefaultMessageTimeToLiveInSec = ptr.Of(int(duration.Seconds()))
}
m.MaxDeliveryCount = ptr.Of(int32(valAsInt))
}
if val, ok := md[keyLockDurationInSec]; ok && val != "" {
var valAsInt int
valAsInt, err = strconv.Atoi(val)
if err != nil {
return m, fmt.Errorf("invalid lockDurationInSec %s: %s", val, err)
}
m.LockDurationInSec = &valAsInt
}
if val, ok := md[keyDefaultMessageTimeToLiveInSec]; ok && val != "" {
var valAsInt int
valAsInt, err = strconv.Atoi(val)
if err == nil && valAsInt < 0 {
err = errors.New("must be greater than 0")
}
if err != nil {
return m, fmt.Errorf("invalid defaultMessageTimeToLiveInSec %s: %s", val, err)
}
m.DefaultMessageTimeToLiveInSec = &valAsInt
} else if val, ok := md[mdutils.TTLMetadataKey]; ok && val != "" {
var valAsInt int
valAsInt, err = strconv.Atoi(val)
if err == nil && valAsInt < 0 {
err = errors.New("must be greater than 0")
}
if err != nil {
return m, fmt.Errorf("invalid %s %s: %s", mdutils.TTLMetadataKey, val, err)
}
m.DefaultMessageTimeToLiveInSec = &valAsInt
}
if val, ok := md[keyAutoDeleteOnIdleInSec]; ok && val != "" {
var valAsInt int
valAsInt, err = strconv.Atoi(val)
if err != nil {
return m, fmt.Errorf("invalid autoDeleteOnIdleInSecKey %s: %s", val, err)
}
m.AutoDeleteOnIdleInSec = &valAsInt
if m.DefaultMessageTimeToLiveInSec != nil && *m.DefaultMessageTimeToLiveInSec == 0 {
return m, errors.New("defaultMessageTimeToLiveInSec must be greater than 0")
}
return m, nil

View File

@ -227,7 +227,7 @@ func TestParseServiceBusMetadata(t *testing.T) {
t.Run("missing optional timeoutInSec", func(t *testing.T) {
fakeProperties := getFakeProperties()
fakeProperties[keyTimeoutInSec] = ""
delete(fakeProperties, keyTimeoutInSec)
// act.
m, err := ParseMetadata(fakeProperties, nil, 0)
@ -286,7 +286,7 @@ func TestParseServiceBusMetadata(t *testing.T) {
t.Run("missing optional handlerTimeoutInSec pubsub", func(t *testing.T) {
fakeProperties := getFakeProperties()
fakeProperties[keyHandlerTimeoutInSec] = ""
delete(fakeProperties, keyHandlerTimeoutInSec)
// act.
m, err := ParseMetadata(fakeProperties, nil, 0)
@ -309,7 +309,7 @@ func TestParseServiceBusMetadata(t *testing.T) {
t.Run("missing optional lockRenewalInSec", func(t *testing.T) {
fakeProperties := getFakeProperties()
fakeProperties[keyLockRenewalInSec] = ""
delete(fakeProperties, keyLockRenewalInSec)
// act.
m, err := ParseMetadata(fakeProperties, nil, 0)
@ -332,7 +332,7 @@ func TestParseServiceBusMetadata(t *testing.T) {
t.Run("missing optional maxRetriableErrorsPerSec", func(t *testing.T) {
fakeProperties := getFakeProperties()
fakeProperties[keyMaxRetriableErrorsPerSec] = ""
delete(fakeProperties, keyMaxRetriableErrorsPerSec)
// act.
m, err := ParseMetadata(fakeProperties, nil, 0)
@ -343,22 +343,12 @@ func TestParseServiceBusMetadata(t *testing.T) {
})
t.Run("invalid optional maxRetriableErrorsPerSec", func(t *testing.T) {
// NaN: Not a Number
fakeProperties := getFakeProperties()
fakeProperties[keyMaxRetriableErrorsPerSec] = invalidNumber
// act.
_, err := ParseMetadata(fakeProperties, nil, 0)
// assert.
assert.Error(t, err)
// Negative number
fakeProperties = getFakeProperties()
fakeProperties := getFakeProperties()
fakeProperties[keyMaxRetriableErrorsPerSec] = "-1"
// act.
_, err = ParseMetadata(fakeProperties, nil, 0)
_, err := ParseMetadata(fakeProperties, nil, 0)
// assert.
assert.Error(t, err)
@ -366,7 +356,7 @@ func TestParseServiceBusMetadata(t *testing.T) {
t.Run("missing optional maxActiveMessages binding", func(t *testing.T) {
fakeProperties := getFakeProperties()
fakeProperties[keyMaxActiveMessages] = ""
delete(fakeProperties, keyMaxActiveMessages)
// act.
m, err := ParseMetadata(fakeProperties, nil, MetadataModeBinding)
@ -378,7 +368,7 @@ func TestParseServiceBusMetadata(t *testing.T) {
t.Run("missing optional maxActiveMessages pubsub", func(t *testing.T) {
fakeProperties := getFakeProperties()
fakeProperties[keyMaxActiveMessages] = ""
delete(fakeProperties, keyMaxActiveMessages)
// act.
m, err := ParseMetadata(fakeProperties, nil, 0)
@ -401,7 +391,7 @@ func TestParseServiceBusMetadata(t *testing.T) {
t.Run("missing optional maxConnectionRecoveryInSec", func(t *testing.T) {
fakeProperties := getFakeProperties()
fakeProperties[keyMaxConnectionRecoveryInSec] = ""
delete(fakeProperties, keyMaxConnectionRecoveryInSec)
// act.
m, err := ParseMetadata(fakeProperties, nil, 0)
@ -424,7 +414,7 @@ func TestParseServiceBusMetadata(t *testing.T) {
t.Run("missing optional minConnectionRecoveryInSec", func(t *testing.T) {
fakeProperties := getFakeProperties()
fakeProperties[keyMinConnectionRecoveryInSec] = ""
delete(fakeProperties, keyMinConnectionRecoveryInSec)
// act.
m, err := ParseMetadata(fakeProperties, nil, 0)
@ -447,7 +437,7 @@ func TestParseServiceBusMetadata(t *testing.T) {
t.Run("missing optional maxConcurrentHandlers", func(t *testing.T) {
fakeProperties := getFakeProperties()
fakeProperties[keyMaxConcurrentHandlers] = ""
delete(fakeProperties, keyMaxConcurrentHandlers)
// act.
m, err := ParseMetadata(fakeProperties, nil, 0)
@ -470,7 +460,7 @@ func TestParseServiceBusMetadata(t *testing.T) {
t.Run("missing nullable maxDeliveryCount", func(t *testing.T) {
fakeProperties := getFakeProperties()
fakeProperties[keyMaxDeliveryCount] = ""
delete(fakeProperties, keyMaxDeliveryCount)
// act.
m, err := ParseMetadata(fakeProperties, nil, 0)
@ -480,20 +470,9 @@ func TestParseServiceBusMetadata(t *testing.T) {
assert.Nil(t, err)
})
t.Run("invalid nullable maxDeliveryCount", func(t *testing.T) {
fakeProperties := getFakeProperties()
fakeProperties[keyMaxDeliveryCount] = invalidNumber
// act.
_, err := ParseMetadata(fakeProperties, nil, 0)
// assert.
assert.Error(t, err)
})
t.Run("missing nullable defaultMessageTimeToLiveInSec", func(t *testing.T) {
fakeProperties := getFakeProperties()
fakeProperties[keyDefaultMessageTimeToLiveInSec] = ""
delete(fakeProperties, keyDefaultMessageTimeToLiveInSec)
// act.
m, err := ParseMetadata(fakeProperties, nil, 0)
@ -503,20 +482,9 @@ func TestParseServiceBusMetadata(t *testing.T) {
assert.Nil(t, err)
})
t.Run("invalid nullable defaultMessageTimeToLiveInSec", func(t *testing.T) {
fakeProperties := getFakeProperties()
fakeProperties[keyDefaultMessageTimeToLiveInSec] = invalidNumber
// act.
_, err := ParseMetadata(fakeProperties, nil, 0)
// assert.
assert.Error(t, err)
})
t.Run("missing nullable autoDeleteOnIdleInSec", func(t *testing.T) {
fakeProperties := getFakeProperties()
fakeProperties[keyAutoDeleteOnIdleInSec] = ""
delete(fakeProperties, keyAutoDeleteOnIdleInSec)
// act.
m, err := ParseMetadata(fakeProperties, nil, 0)
@ -526,20 +494,9 @@ func TestParseServiceBusMetadata(t *testing.T) {
assert.Nil(t, err)
})
t.Run("invalid nullable autoDeleteOnIdleInSec", func(t *testing.T) {
fakeProperties := getFakeProperties()
fakeProperties[keyAutoDeleteOnIdleInSec] = invalidNumber
// act.
_, err := ParseMetadata(fakeProperties, nil, 0)
// assert.
assert.Error(t, err)
})
t.Run("missing nullable lockDurationInSec", func(t *testing.T) {
fakeProperties := getFakeProperties()
fakeProperties[keyLockDurationInSec] = ""
delete(fakeProperties, keyLockDurationInSec)
// act.
m, err := ParseMetadata(fakeProperties, nil, 0)
@ -548,15 +505,4 @@ func TestParseServiceBusMetadata(t *testing.T) {
assert.Nil(t, m.LockDurationInSec)
assert.Nil(t, err)
})
t.Run("invalid nullable lockDurationInSec", func(t *testing.T) {
fakeProperties := getFakeProperties()
fakeProperties[keyLockDurationInSec] = invalidNumber
// act.
_, err := ParseMetadata(fakeProperties, nil, 0)
// assert.
assert.Error(t, err)
})
}

View File

@ -77,6 +77,8 @@ func toTimeDurationHookFunc() mapstructure.DecodeHookFunc {
}
switch f.Kind() {
case reflect.TypeOf(time.Duration(0)).Kind():
return data.(time.Duration), nil
case reflect.String:
var val time.Duration
if data.(string) != "" {

View File

@ -14,6 +14,7 @@ limitations under the License.
package metadata
import (
"errors"
"fmt"
"math"
"reflect"
@ -153,6 +154,7 @@ func DecodeMetadata(input interface{}, result interface{}) error {
decoder, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{
DecodeHook: mapstructure.ComposeDecodeHookFunc(
toTimeDurationArrayHookFunc(),
toTimeDurationHookFunc(),
toTruthyBoolHookFunc(),
toStringArrayHookFunc(),
@ -204,6 +206,49 @@ func toStringArrayHookFunc() mapstructure.DecodeHookFunc {
}
}
// toTimeDurationArrayHookFunc returns a mapstructure decode hook that turns a
// comma-separated string into a []time.Duration (or *[]time.Duration).
// Each element may be a Go duration string (e.g. "5s") or a bare integer,
// which is interpreted as a number of seconds. Blank elements are skipped.
func toTimeDurationArrayHookFunc() mapstructure.DecodeHookFunc {
	// parseList does the actual string -> []time.Duration conversion.
	parseList := func(raw string) ([]time.Duration, error) {
		out := make([]time.Duration, 0)
		for _, piece := range strings.Split(raw, ",") {
			piece = strings.TrimSpace(piece)
			if piece == "" {
				continue
			}
			d, err := time.ParseDuration(piece)
			if err != nil {
				// Not a duration literal; fall back to integer seconds.
				secs, errParse := strconv.ParseInt(piece, 10, 0)
				if errParse != nil {
					return nil, errors.Join(err, errParse)
				}
				d = time.Duration(secs) * time.Second
			}
			out = append(out, d)
		}
		return out, nil
	}
	stringType := reflect.TypeOf("")
	sliceType := reflect.TypeOf([]time.Duration{})
	ptrSliceType := reflect.TypeOf(ptr.Of([]time.Duration{}))
	return func(
		f reflect.Type,
		t reflect.Type,
		data interface{},
	) (interface{}, error) {
		if f != stringType {
			return data, nil
		}
		switch t {
		case sliceType:
			return parseList(data.(string))
		case ptrSliceType:
			res, err := parseList(data.(string))
			if err != nil {
				return nil, err
			}
			return ptr.Of(res), nil
		}
		return data, nil
	}
}
type ComponentType string
const (
@ -232,7 +277,15 @@ func GetMetadataInfoFromStructType(t reflect.Type, metadataMap *map[string]strin
for i := 0; i < t.NumField(); i++ {
currentField := t.Field(i)
// fields that are not exported cannot be set via the mapstructure metadata decoding mechanism
if !currentField.IsExported() {
continue
}
mapStructureTag := currentField.Tag.Get("mapstructure")
// we are not exporting this field using the mapstructure tag mechanism
if mapStructureTag == "-" {
continue
}
onlyTag := currentField.Tag.Get("only")
if onlyTag != "" {
include := false

View File

@ -97,17 +97,20 @@ func TestTryGetContentType(t *testing.T) {
func TestMetadataDecode(t *testing.T) {
t.Run("Test metadata decoding", func(t *testing.T) {
type testMetadata struct {
Mystring string `json:"mystring"`
Myduration Duration `json:"myduration"`
Myinteger int `json:"myinteger,string"`
Myfloat64 float64 `json:"myfloat64,string"`
Mybool *bool `json:"mybool,omitempty"`
MyRegularDuration time.Duration `json:"myregularduration"`
MyDurationWithoutUnit time.Duration `json:"mydurationwithoutunit"`
MyRegularDurationEmpty time.Duration `json:"myregulardurationempty"`
Mystring string `mapstructure:"mystring"`
Myduration Duration `mapstructure:"myduration"`
Myinteger int `mapstructure:"myinteger"`
Myfloat64 float64 `mapstructure:"myfloat64"`
Mybool *bool `mapstructure:"mybool"`
MyRegularDuration time.Duration `mapstructure:"myregularduration"`
MyDurationWithoutUnit time.Duration `mapstructure:"mydurationwithoutunit"`
MyRegularDurationEmpty time.Duration `mapstructure:"myregulardurationempty"`
MyDurationArray []time.Duration `mapstructure:"mydurationarray"`
MyDurationArrayPointer *[]time.Duration `mapstructure:"mydurationarraypointer"`
MyDurationArrayPointerEmpty *[]time.Duration `mapstructure:"mydurationarraypointerempty"`
MyRegularDurationDefaultValueUnset time.Duration `json:"myregulardurationdefaultvalueunset"`
MyRegularDurationDefaultValueEmpty time.Duration `json:"myregulardurationdefaultvalueempty"`
MyRegularDurationDefaultValueUnset time.Duration `mapstructure:"myregulardurationdefaultvalueunset"`
MyRegularDurationDefaultValueEmpty time.Duration `mapstructure:"myregulardurationdefaultvalueempty"`
}
var m testMetadata
@ -125,6 +128,9 @@ func TestMetadataDecode(t *testing.T) {
"myregulardurationempty": "",
// Not setting myregulardurationdefaultvalueunset on purpose
"myregulardurationdefaultvalueempty": "",
"mydurationarray": "1s,2s,3s,10",
"mydurationarraypointer": "1s,10,2s,20,3s,30",
"mydurationarraypointerempty": ",",
}
err := DecodeMetadata(testData, &m)
@ -140,6 +146,9 @@ func TestMetadataDecode(t *testing.T) {
assert.Equal(t, time.Duration(0), m.MyRegularDurationEmpty)
assert.Equal(t, time.Hour, m.MyRegularDurationDefaultValueUnset)
assert.Equal(t, time.Duration(0), m.MyRegularDurationDefaultValueEmpty)
assert.Equal(t, []time.Duration{time.Second, time.Second * 2, time.Second * 3, time.Second * 10}, m.MyDurationArray)
assert.Equal(t, []time.Duration{time.Second, time.Second * 10, time.Second * 2, time.Second * 20, time.Second * 3, time.Second * 30}, *m.MyDurationArrayPointer)
assert.Equal(t, []time.Duration{}, *m.MyDurationArrayPointerEmpty)
})
t.Run("Test metadata decode hook for truthy values", func(t *testing.T) {
@ -228,17 +237,20 @@ func TestMetadataStructToStringMap(t *testing.T) {
}
type testMetadata struct {
NestedStruct `mapstructure:",squash"`
Mystring string
Myduration Duration
Myinteger int
Myfloat64 float64
Mybool *bool `json:",omitempty"`
MyRegularDuration time.Duration
SomethingWithCustomName string `mapstructure:"something_with_custom_name"`
PubSubOnlyProperty string `mapstructure:"pubsub_only_property" only:"pubsub"`
BindingOnlyProperty string `mapstructure:"binding_only_property" only:"binding"`
PubSubAndBindingProperty string `mapstructure:"pubsub_and_binding_property" only:"pubsub,binding"`
NestedStruct `mapstructure:",squash"`
Mystring string
Myduration Duration
Myinteger int
Myfloat64 float64
Mybool *bool
MyRegularDuration time.Duration
SomethingWithCustomName string `mapstructure:"something_with_custom_name"`
PubSubOnlyProperty string `mapstructure:"pubsub_only_property" only:"pubsub"`
BindingOnlyProperty string `mapstructure:"binding_only_property" only:"binding"`
PubSubAndBindingProperty string `mapstructure:"pubsub_and_binding_property" only:"pubsub,binding"`
MyDurationArray []time.Duration
NotExportedByMapStructure string `mapstructure:"-"`
notExported string //nolint:structcheck,unused
}
m := testMetadata{}
metadatainfo := map[string]string{}
@ -258,5 +270,8 @@ func TestMetadataStructToStringMap(t *testing.T) {
assert.NotContains(t, metadatainfo, "pubsub_only_property")
assert.Equal(t, "string", metadatainfo["binding_only_property"])
assert.Equal(t, "string", metadatainfo["pubsub_and_binding_property"])
assert.Equal(t, "[]time.Duration", metadatainfo["MyDurationArray"])
assert.NotContains(t, metadatainfo, "NotExportedByMapStructure")
assert.NotContains(t, metadatainfo, "notExported")
})
}

View File

@ -3,9 +3,8 @@ package snssqs
import (
"errors"
"fmt"
"strconv"
mdutils "github.com/dapr/components-contrib/metadata"
"github.com/dapr/components-contrib/metadata"
"github.com/dapr/components-contrib/pubsub"
"github.com/aws/aws-sdk-go/aws/endpoints"
@ -13,74 +12,49 @@ import (
type snsSqsMetadata struct {
// aws endpoint for the component to use.
Endpoint string
Endpoint string `mapstructure:"endpoint"`
// access key to use for accessing sqs/sns.
AccessKey string
AccessKey string `mapstructure:"accessKey"`
// secret key to use for accessing sqs/sns.
SecretKey string
SecretKey string `mapstructure:"secretKey"`
// aws session token to use.
SessionToken string
SessionToken string `mapstructure:"sessionToken"`
// aws region in which SNS/SQS should create resources.
Region string
Region string `mapstructure:"region"`
// aws partition in which SNS/SQS should create resources.
Partition string
internalPartition string `mapstructure:"-"`
// name of the queue for this application. The is provided by the runtime as "consumerID".
sqsQueueName string
SqsQueueName string `mapstructure:"consumerID"`
// name of the dead letter queue for this application.
sqsDeadLettersQueueName string
SqsDeadLettersQueueName string `mapstructure:"sqsDeadLettersQueueName"`
// flag to SNS and SQS FIFO.
fifo bool
Fifo bool `mapstructure:"fifo"`
// a namespace for SNS SQS FIFO to order messages within that group. limits consumer concurrency if set but guarantees that all
// published messages would be ordered by their arrival time to SQS.
// see: https://aws.amazon.com/blogs/compute/solving-complex-ordering-challenges-with-amazon-sqs-fifo-queues/
fifoMessageGroupID string
FifoMessageGroupID string `mapstructure:"fifoMessageGroupID"`
// amount of time in seconds that a message is hidden from receive requests after it is sent to a subscriber. Default: 10.
messageVisibilityTimeout int64
MessageVisibilityTimeout int64 `mapstructure:"messageVisibilityTimeout"`
// number of times to resend a message after processing of that message fails before removing that message from the queue. Default: 10.
messageRetryLimit int64
MessageRetryLimit int64 `mapstructure:"messageRetryLimit"`
// upon reaching the messageRetryLimit, disables the default deletion behaviour of the message from the SQS queue, and resetting the message visibilty on SQS
// so that other consumers can try consuming that message.
disableDeleteOnRetryLimit bool
// if sqsDeadLettersQueueName is set to a value, then the messageReceiveLimit defines the number of times a message is received
DisableDeleteOnRetryLimit bool `mapstructure:"disableDeleteOnRetryLimit"`
// if sqsDeadLettersQueueName is set to a value, then the MessageReceiveLimit defines the number of times a message is received
// before it is moved to the dead-letters queue. This value must be smaller than messageRetryLimit.
messageReceiveLimit int64
MessageReceiveLimit int64 `mapstructure:"messageReceiveLimit"`
// amount of time to await receipt of a message before making another request. Default: 2.
messageWaitTimeSeconds int64
MessageWaitTimeSeconds int64 `mapstructure:"messageWaitTimeSeconds"`
// maximum number of messages to receive from the queue at a time. Default: 10, Maximum: 10.
messageMaxNumber int64
MessageMaxNumber int64 `mapstructure:"messageMaxNumber"`
// disable resource provisioning of SNS and SQS.
disableEntityManagement bool
DisableEntityManagement bool `mapstructure:"disableEntityManagement"`
// assets creation timeout.
assetsManagementTimeoutSeconds float64
AssetsManagementTimeoutSeconds float64 `mapstructure:"assetsManagementTimeoutSeconds"`
// aws account ID. internally resolved if not given.
accountID string
AccountID string `mapstructure:"accountID"`
// processing concurrency mode
concurrencyMode pubsub.ConcurrencyMode
}
// parseInt64 parses input as a base-10 64-bit integer.
// propertyName is only used to identify the offending metadata property in
// the returned error. On failure the returned value is -1.
//
// Note: the previous implementation used strconv.Atoi, which parses a
// platform-sized int; on 32-bit builds that rejected valid int64 values
// outside the int32 range. ParseInt with bitSize 64 covers the full range.
func parseInt64(input string, propertyName string) (int64, error) {
	number, err := strconv.ParseInt(input, 10, 64)
	if err != nil {
		return -1, fmt.Errorf("parsing %s failed with: %w", propertyName, err)
	}
	return number, nil
}
// parseBool parses input as a boolean, wrapping any parse failure in an
// error that names the offending metadata property. On failure the returned
// value is false.
func parseBool(input string, propertyName string) (bool, error) {
	result, parseErr := strconv.ParseBool(input)
	if parseErr == nil {
		return result, nil
	}
	return false, fmt.Errorf("parsing %s failed with: %w", propertyName, parseErr)
}
func parseFloat64(input string, propertyName string) (float64, error) {
val, err := strconv.ParseFloat(input, 64)
if err != nil {
return 0, fmt.Errorf("parsing %s failed with: %w", propertyName, err)
}
return val, nil
ConcurrencyMode pubsub.ConcurrencyMode `mapstructure:"concurrencyMode"`
}
func maskLeft(s string) string {
@ -91,51 +65,66 @@ func maskLeft(s string) string {
return string(rs)
}
func (s *snsSqs) getSnsSqsMetatdata(metadata pubsub.Metadata) (*snsSqsMetadata, error) {
md := &snsSqsMetadata{}
if err := md.setCredsAndQueueNameConfig(metadata); err != nil {
func (s *snsSqs) getSnsSqsMetatdata(meta pubsub.Metadata) (*snsSqsMetadata, error) {
md := &snsSqsMetadata{
AssetsManagementTimeoutSeconds: assetsManagementDefaultTimeoutSeconds,
MessageVisibilityTimeout: 10,
MessageRetryLimit: 10,
MessageWaitTimeSeconds: 2,
MessageMaxNumber: 10,
}
upgradeMetadata(&meta)
err := metadata.DecodeMetadata(meta.Properties, md)
if err != nil {
return nil, err
}
props := metadata.Properties
if err := md.setMessageVisibilityTimeout(props); err != nil {
return nil, err
if md.Region != "" {
if partition, ok := endpoints.PartitionForRegion(endpoints.DefaultPartitions(), md.Region); ok {
md.internalPartition = partition.ID()
} else {
md.internalPartition = "aws"
}
}
if err := md.setMessageRetryLimit(props); err != nil {
return nil, err
if md.SqsQueueName == "" {
return nil, errors.New("consumerID must be set")
}
if err := md.setDeadlettersQueueConfig(props); err != nil {
return nil, err
if md.MessageVisibilityTimeout < 1 {
return nil, errors.New("messageVisibilityTimeout must be greater than 0")
}
if err := md.setDisableDeleteOnRetryLimit(props); err != nil {
return nil, err
if md.MessageRetryLimit < 2 {
return nil, errors.New("messageRetryLimit must be greater than 1")
}
if err := md.setFifoConfig(props); err != nil {
return nil, err
// XOR on having either a valid messageReceiveLimit and invalid sqsDeadLettersQueueName, and vice versa.
if (md.MessageReceiveLimit > 0 || len(md.SqsDeadLettersQueueName) > 0) && !(md.MessageReceiveLimit > 0 && len(md.SqsDeadLettersQueueName) > 0) {
return nil, errors.New("to use SQS dead letters queue, messageReceiveLimit and sqsDeadLettersQueueName must both be set to a value")
}
if err := md.setMessageWaitTimeSeconds(props); err != nil {
return nil, err
if len(md.SqsDeadLettersQueueName) > 0 && md.DisableDeleteOnRetryLimit {
return nil, errors.New("configuration conflict: 'disableDeleteOnRetryLimit' cannot be set to 'true' when 'sqsDeadLettersQueueName' is set to a value. either remove this configuration or set 'disableDeleteOnRetryLimit' to 'false'")
}
if err := md.setMessageMaxNumber(props); err != nil {
return nil, err
if md.MessageWaitTimeSeconds < 1 {
return nil, errors.New("messageWaitTimeSeconds must be greater than 0")
}
if err := md.setDisableEntityManagement(props); err != nil {
return nil, err
// fifo settings: assign user provided Message Group ID
// for more details, see: https://docs.aws.amazon.com/AWSSimpleQueueService/latest/SQSDeveloperGuide/using-messagegroupid-property.html
if md.FifoMessageGroupID == "" {
md.FifoMessageGroupID = meta.Properties[pubsub.RuntimeConsumerIDKey]
}
if err := md.setAssetsManagementTimeoutSeconds(props); err != nil {
return nil, err
if md.MessageMaxNumber < 1 {
return nil, errors.New("messageMaxNumber must be greater than 0")
} else if md.MessageMaxNumber > 10 {
return nil, errors.New("messageMaxNumber must be less than or equal to 10")
}
if err := md.setConcurrencyMode(props); err != nil {
if err := md.setConcurrencyMode(meta.Properties); err != nil {
return nil, err
}
@ -149,7 +138,7 @@ func (md *snsSqsMetadata) setConcurrencyMode(props map[string]string) error {
if err != nil {
return err
}
md.concurrencyMode = c
md.ConcurrencyMode = c
return nil
}
@ -163,207 +152,18 @@ func (md *snsSqsMetadata) hideDebugPrintedCredentials() string {
return fmt.Sprintf("%#v\n", mdCopy)
}
func (md *snsSqsMetadata) setCredsAndQueueNameConfig(metadata pubsub.Metadata) error {
if val, ok := mdutils.GetMetadataProperty(metadata.Properties, "Endpoint", "endpoint"); ok {
md.Endpoint = val
func upgradeMetadata(m *pubsub.Metadata) {
upgradeMap := map[string]string{
"Endpoint": "endpoint",
"awsAccountID": "accessKey",
"awsSecret": "secretKey",
"awsRegion": "region",
}
if val, ok := mdutils.GetMetadataProperty(metadata.Properties, "awsAccountID", "accessKey"); ok {
md.AccessKey = val
}
if val, ok := mdutils.GetMetadataProperty(metadata.Properties, "awsSecret", "secretKey"); ok {
md.SecretKey = val
}
if val, ok := metadata.Properties["sessionToken"]; ok {
md.SessionToken = val
}
if val, ok := mdutils.GetMetadataProperty(metadata.Properties, "awsRegion", "region"); ok {
md.Region = val
if partition, ok := endpoints.PartitionForRegion(endpoints.DefaultPartitions(), val); ok {
md.Partition = partition.ID()
} else {
md.Partition = "aws"
for oldKey, newKey := range upgradeMap {
if val, ok := m.Properties[oldKey]; ok {
m.Properties[newKey] = val
delete(m.Properties, oldKey)
}
}
if val, ok := metadata.Properties["consumerID"]; ok {
md.sqsQueueName = val
} else {
return errors.New("consumerID must be set")
}
return nil
}
// setAssetsManagementTimeoutSeconds reads the optional
// "assetsManagementTimeoutSeconds" property into md, falling back to the
// package default when the property is absent. Returns an error when the
// value is present but is not a valid float.
func (md *snsSqsMetadata) setAssetsManagementTimeoutSeconds(props map[string]string) error {
	raw, present := props["assetsManagementTimeoutSeconds"]
	if !present {
		md.assetsManagementTimeoutSeconds = assetsManagementDefaultTimeoutSeconds
		return nil
	}

	timeout, err := parseFloat64(raw, "assetsManagementTimeoutSeconds")
	if err != nil {
		return err
	}

	md.assetsManagementTimeoutSeconds = timeout
	return nil
}
// setDisableEntityManagement reads the optional "disableEntityManagement"
// boolean property into md; when the property is absent the field is left
// untouched (zero value: false).
func (md *snsSqsMetadata) setDisableEntityManagement(props map[string]string) error {
	raw, present := props["disableEntityManagement"]
	if !present {
		return nil
	}

	disable, err := parseBool(raw, "disableEntityManagement")
	if err != nil {
		return err
	}

	md.disableEntityManagement = disable
	return nil
}
// setMessageMaxNumber reads the optional "messageMaxNumber" property into md
// (default 10) and rejects values outside the accepted range of 1..10.
func (md *snsSqsMetadata) setMessageMaxNumber(props map[string]string) error {
	raw, present := props["messageMaxNumber"]
	if !present {
		md.messageMaxNumber = 10
		return nil
	}

	n, err := parseInt64(raw, "messageMaxNumber")
	if err != nil {
		return err
	}
	switch {
	case n < 1:
		return errors.New("messageMaxNumber must be greater than 0")
	case n > 10:
		return errors.New("messageMaxNumber must be less than or equal to 10")
	}

	md.messageMaxNumber = n
	return nil
}
// setMessageWaitTimeSeconds reads the optional "messageWaitTimeSeconds"
// property into md (default 2); values below 1 are rejected.
func (md *snsSqsMetadata) setMessageWaitTimeSeconds(props map[string]string) error {
	raw, present := props["messageWaitTimeSeconds"]
	if !present {
		md.messageWaitTimeSeconds = 2
		return nil
	}

	wait, err := parseInt64(raw, "messageWaitTimeSeconds")
	if err != nil {
		return err
	}
	if wait < 1 {
		return errors.New("messageWaitTimeSeconds must be greater than 0")
	}

	md.messageWaitTimeSeconds = wait
	return nil
}
// setFifoConfig reads the FIFO-related properties into md: "fifo" toggles
// SNS/SQS FIFO mode (default false), and "fifoMessageGroupID" sets a
// user-provided message group ID, which otherwise falls back to the runtime
// consumer ID. For details, see:
// https://docs.aws.amazon.com/AWSSimpleQueueService/latest/SQSDeveloperGuide/using-messagegroupid-property.html
func (md *snsSqsMetadata) setFifoConfig(props map[string]string) error {
	// fifo settings: enable/disable SNS and SQS FIFO.
	if raw, present := props["fifo"]; present {
		enabled, err := parseBool(raw, "fifo")
		if err != nil {
			return err
		}
		md.fifo = enabled
	} else {
		md.fifo = false
	}

	// fifo settings: assign user provided Message Group ID, or fall back to
	// the consumer ID supplied by the runtime.
	groupID, present := props["fifoMessageGroupID"]
	if !present {
		groupID = props[pubsub.RuntimeConsumerIDKey]
	}
	md.fifoMessageGroupID = groupID

	return nil
}
// setDeadlettersQueueConfig reads the dead-letter queue settings into md and
// validates that "sqsDeadLettersQueueName" and "messageReceiveLimit" are
// either both configured or both absent.
func (md *snsSqsMetadata) setDeadlettersQueueConfig(props map[string]string) error {
	if name, present := props["sqsDeadLettersQueueName"]; present {
		md.sqsDeadLettersQueueName = name
	}

	if raw, present := props["messageReceiveLimit"]; present {
		limit, err := parseInt64(raw, "messageReceiveLimit")
		if err != nil {
			return err
		}
		// assign: used provided configuration
		md.messageReceiveLimit = limit
	}

	// The two settings only make sense together: reject one without the other
	// (XOR on the two being configured).
	hasLimit := md.messageReceiveLimit > 0
	hasQueue := len(md.sqsDeadLettersQueueName) > 0
	if hasLimit != hasQueue {
		return errors.New("to use SQS dead letters queue, messageReceiveLimit and sqsDeadLettersQueueName must both be set to a value")
	}

	return nil
}
// setDisableDeleteOnRetryLimit reads the optional "disableDeleteOnRetryLimit"
// boolean property into md (default false). Enabling it while a dead-letters
// queue is configured is a conflict and is rejected.
func (md *snsSqsMetadata) setDisableDeleteOnRetryLimit(props map[string]string) error {
	raw, present := props["disableDeleteOnRetryLimit"]
	if !present {
		// default when not configured.
		md.disableDeleteOnRetryLimit = false
		return nil
	}

	disable, err := parseBool(raw, "disableDeleteOnRetryLimit")
	if err != nil {
		return err
	}
	if disable && len(md.sqsDeadLettersQueueName) > 0 {
		return errors.New("configuration conflict: 'disableDeleteOnRetryLimit' cannot be set to 'true' when 'sqsDeadLettersQueueName' is set to a value. either remove this configuration or set 'disableDeleteOnRetryLimit' to 'false'")
	}

	md.disableDeleteOnRetryLimit = disable
	return nil
}
// setMessageRetryLimit reads the optional "messageRetryLimit" property into
// md (default 10); values below 2 are rejected.
func (md *snsSqsMetadata) setMessageRetryLimit(props map[string]string) error {
	raw, present := props["messageRetryLimit"]
	if !present {
		md.messageRetryLimit = 10
		return nil
	}

	limit, err := parseInt64(raw, "messageRetryLimit")
	if err != nil {
		return err
	}
	if limit < 2 {
		return errors.New("messageRetryLimit must be greater than 1")
	}

	md.messageRetryLimit = limit
	return nil
}
// setMessageVisibilityTimeout reads the optional "messageVisibilityTimeout"
// property into md (default 10); values below 1 are rejected.
func (md *snsSqsMetadata) setMessageVisibilityTimeout(props map[string]string) error {
	raw, present := props["messageVisibilityTimeout"]
	if !present {
		md.messageVisibilityTimeout = 10
		return nil
	}

	timeout, err := parseInt64(raw, "messageVisibilityTimeout")
	if err != nil {
		return err
	}
	if timeout < 1 {
		return errors.New("messageVisibilityTimeout must be greater than 0")
	}

	md.messageVisibilityTimeout = timeout
	return nil
}

View File

@ -18,6 +18,7 @@ import (
"encoding/json"
"errors"
"fmt"
"reflect"
"strconv"
"strings"
"sync"
@ -35,6 +36,7 @@ import (
gonanoid "github.com/matoous/go-nanoid/v2"
awsAuth "github.com/dapr/components-contrib/internal/authentication/aws"
"github.com/dapr/components-contrib/metadata"
"github.com/dapr/components-contrib/pubsub"
"github.com/dapr/kit/logger"
)
@ -174,7 +176,7 @@ func (s *snsSqs) Init(ctx context.Context, metadata pubsub.Metadata) error {
s.sqsClient = sqs.New(sess)
s.stsClient = sts.New(sess)
s.opsTimeout = time.Duration(md.assetsManagementTimeoutSeconds * float64(time.Second))
s.opsTimeout = time.Duration(md.AssetsManagementTimeoutSeconds * float64(time.Second))
err = s.setAwsAccountIDIfNotProvided(ctx)
if err != nil {
@ -192,7 +194,7 @@ func (s *snsSqs) Init(ctx context.Context, metadata pubsub.Metadata) error {
}
func (s *snsSqs) setAwsAccountIDIfNotProvided(parentCtx context.Context) error {
if len(s.metadata.accountID) == awsAccountIDLength {
if len(s.metadata.AccountID) == awsAccountIDLength {
return nil
}
@ -203,22 +205,22 @@ func (s *snsSqs) setAwsAccountIDIfNotProvided(parentCtx context.Context) error {
return fmt.Errorf("error fetching sts caller ID: %w", err)
}
s.metadata.accountID = *callerIDOutput.Account
s.metadata.AccountID = *callerIDOutput.Account
return nil
}
func (s *snsSqs) buildARN(serviceName, entityName string) string {
return fmt.Sprintf("arn:%s:%s:%s:%s:%s", s.metadata.Partition, serviceName, s.metadata.Region, s.metadata.accountID, entityName)
return fmt.Sprintf("arn:%s:%s:%s:%s:%s", s.metadata.internalPartition, serviceName, s.metadata.Region, s.metadata.AccountID, entityName)
}
func (s *snsSqs) createTopic(parentCtx context.Context, topic string) (string, error) {
sanitizedName := nameToAWSSanitizedName(topic, s.metadata.fifo)
sanitizedName := nameToAWSSanitizedName(topic, s.metadata.Fifo)
snsCreateTopicInput := &sns.CreateTopicInput{
Name: aws.String(sanitizedName),
Tags: []*sns.Tag{{Key: aws.String(awsSnsTopicNameKey), Value: aws.String(topic)}},
}
if s.metadata.fifo {
if s.metadata.Fifo {
attributes := map[string]*string{"FifoTopic": aws.String("true"), "ContentBasedDeduplication": aws.String("true")}
snsCreateTopicInput.SetAttributes(attributes)
}
@ -253,7 +255,7 @@ func (s *snsSqs) getOrCreateTopic(ctx context.Context, topic string) (topicArn s
s.topicsLock.Lock()
defer s.topicsLock.Unlock()
sanitizedName = nameToAWSSanitizedName(topic, s.metadata.fifo)
sanitizedName = nameToAWSSanitizedName(topic, s.metadata.Fifo)
topicArnCached, ok := s.topicArns[sanitizedName]
if ok && topicArnCached != "" {
@ -263,7 +265,7 @@ func (s *snsSqs) getOrCreateTopic(ctx context.Context, topic string) (topicArn s
// creating queues is idempotent, the names serve as unique keys among a given region.
s.logger.Debugf("No SNS topic arn found for %s\nCreating SNS topic", topic)
if !s.metadata.disableEntityManagement {
if !s.metadata.DisableEntityManagement {
topicArn, err = s.createTopic(ctx, sanitizedName)
if err != nil {
s.logger.Errorf("error creating new topic %s: %w", topic, err)
@ -286,13 +288,13 @@ func (s *snsSqs) getOrCreateTopic(ctx context.Context, topic string) (topicArn s
}
func (s *snsSqs) createQueue(parentCtx context.Context, queueName string) (*sqsQueueInfo, error) {
sanitizedName := nameToAWSSanitizedName(queueName, s.metadata.fifo)
sanitizedName := nameToAWSSanitizedName(queueName, s.metadata.Fifo)
sqsCreateQueueInput := &sqs.CreateQueueInput{
QueueName: aws.String(sanitizedName),
Tags: map[string]*string{awsSqsQueueNameKey: aws.String(queueName)},
}
if s.metadata.fifo {
if s.metadata.Fifo {
attributes := map[string]*string{"FifoQueue": aws.String("true"), "ContentBasedDeduplication": aws.String("true")}
sqsCreateQueueInput.SetAttributes(attributes)
}
@ -321,7 +323,7 @@ func (s *snsSqs) createQueue(parentCtx context.Context, queueName string) (*sqsQ
func (s *snsSqs) getQueueArn(parentCtx context.Context, queueName string) (*sqsQueueInfo, error) {
ctx, cancel := context.WithTimeout(parentCtx, s.opsTimeout)
queueURLOutput, err := s.sqsClient.GetQueueUrlWithContext(ctx, &sqs.GetQueueUrlInput{QueueName: aws.String(queueName), QueueOwnerAWSAccountId: aws.String(s.metadata.accountID)})
queueURLOutput, err := s.sqsClient.GetQueueUrlWithContext(ctx, &sqs.GetQueueUrlInput{QueueName: aws.String(queueName), QueueOwnerAWSAccountId: aws.String(s.metadata.AccountID)})
cancel()
if err != nil {
return nil, fmt.Errorf("error: %w while getting url of queue: %s", err, queueName)
@ -352,9 +354,9 @@ func (s *snsSqs) getOrCreateQueue(ctx context.Context, queueName string) (*sqsQu
// creating queues is idempotent, the names serve as unique keys among a given region.
s.logger.Debugf("No SQS queue arn found for %s\nCreating SQS queue", queueName)
sanitizedName := nameToAWSSanitizedName(queueName, s.metadata.fifo)
sanitizedName := nameToAWSSanitizedName(queueName, s.metadata.Fifo)
if !s.metadata.disableEntityManagement {
if !s.metadata.DisableEntityManagement {
queueInfo, err = s.createQueue(ctx, sanitizedName)
if err != nil {
s.logger.Errorf("Error creating queue %s: %v", queueName, err)
@ -377,8 +379,8 @@ func (s *snsSqs) getOrCreateQueue(ctx context.Context, queueName string) (*sqsQu
}
func (s *snsSqs) getMessageGroupID(req *pubsub.PublishRequest) *string {
if len(s.metadata.fifoMessageGroupID) > 0 {
return &s.metadata.fifoMessageGroupID
if len(s.metadata.FifoMessageGroupID) > 0 {
return &s.metadata.FifoMessageGroupID
}
// each daprd, of a given PubSub, of a given publisher application publishes to a message group ID of its own.
// for example: for a daprd serving the SNS/SQS Pubsub component we generate a unique id -> A; that component serves on behalf
@ -435,7 +437,7 @@ func (s *snsSqs) getOrCreateSnsSqsSubscription(ctx context.Context, queueArn, to
s.logger.Debugf("No subscription arn found of queue arn:%s to topic arn: %s\nCreating subscription", queueArn, topicArn)
if !s.metadata.disableEntityManagement {
if !s.metadata.DisableEntityManagement {
subscriptionArn, err = s.createSnsSqsSubscription(ctx, queueArn, topicArn)
if err != nil {
s.logger.Errorf("Error creating subscription %s: %v", subscriptionArn, err)
@ -509,10 +511,10 @@ func (s *snsSqs) validateMessage(ctx context.Context, message *sqs.Message, queu
return err
}
messageRetryLimit := s.metadata.messageRetryLimit
messageRetryLimit := s.metadata.MessageRetryLimit
if deadLettersQueueInfo == nil && recvCount >= messageRetryLimit {
// if we are over the allowable retry limit, and there is no dead-letters queue, and we don't disable deletes, then delete the message from the queue.
if !s.metadata.disableDeleteOnRetryLimit {
if !s.metadata.DisableDeleteOnRetryLimit {
if innerErr := s.acknowledgeMessage(ctx, queueInfo.url, message.ReceiptHandle); innerErr != nil {
return fmt.Errorf("error acknowledging message after receiving the message too many times: %w", innerErr)
}
@ -533,7 +535,7 @@ func (s *snsSqs) validateMessage(ctx context.Context, message *sqs.Message, queu
// a message if we've already surpassed the messageRetryLimit value.
if deadLettersQueueInfo != nil && recvCount > messageRetryLimit {
awsErr := fmt.Errorf(
"message received greater than %v times, this message should have been moved without further processing to dead-letters queue: %v", messageRetryLimit, s.metadata.sqsDeadLettersQueueName)
"message received greater than %v times, this message should have been moved without further processing to dead-letters queue: %v", messageRetryLimit, s.metadata.SqsDeadLettersQueueName)
return awsErr
}
@ -581,10 +583,10 @@ func (s *snsSqs) consumeSubscription(ctx context.Context, queueInfo, deadLetters
AttributeNames: []*string{
aws.String(sqs.MessageSystemAttributeNameApproximateReceiveCount),
},
MaxNumberOfMessages: aws.Int64(s.metadata.messageMaxNumber),
MaxNumberOfMessages: aws.Int64(s.metadata.MessageMaxNumber),
QueueUrl: aws.String(queueInfo.url),
VisibilityTimeout: aws.Int64(s.metadata.messageVisibilityTimeout),
WaitTimeSeconds: aws.Int64(s.metadata.messageWaitTimeSeconds),
VisibilityTimeout: aws.Int64(s.metadata.MessageVisibilityTimeout),
WaitTimeSeconds: aws.Int64(s.metadata.MessageWaitTimeSeconds),
}
for {
@ -638,7 +640,7 @@ func (s *snsSqs) consumeSubscription(ctx context.Context, queueInfo, deadLetters
}
wg.Add(1)
switch s.metadata.concurrencyMode {
switch s.metadata.ConcurrencyMode {
case pubsub.Single:
f(message)
case pubsub.Parallel:
@ -659,7 +661,7 @@ func (s *snsSqs) consumeSubscription(ctx context.Context, queueInfo, deadLetters
func (s *snsSqs) createDeadLettersQueueAttributes(queueInfo, deadLettersQueueInfo *sqsQueueInfo) (*sqs.SetQueueAttributesInput, error) {
policy := map[string]string{
"deadLetterTargetArn": deadLettersQueueInfo.arn,
"maxReceiveCount": strconv.FormatInt(s.metadata.messageReceiveLimit, 10),
"maxReceiveCount": strconv.FormatInt(s.metadata.MessageReceiveLimit, 10),
}
b, err := json.Marshal(policy)
@ -681,7 +683,7 @@ func (s *snsSqs) createDeadLettersQueueAttributes(queueInfo, deadLettersQueueInf
}
func (s *snsSqs) setDeadLettersQueueAttributes(parentCtx context.Context, queueInfo, deadLettersQueueInfo *sqsQueueInfo) error {
if s.metadata.disableEntityManagement {
if s.metadata.DisableEntityManagement {
return nil
}
@ -709,7 +711,7 @@ func (s *snsSqs) setDeadLettersQueueAttributes(parentCtx context.Context, queueI
func (s *snsSqs) restrictQueuePublishPolicyToOnlySNS(parentCtx context.Context, sqsQueueInfo *sqsQueueInfo, snsARN string) error {
// not creating any policies of disableEntityManagement is true.
if s.metadata.disableEntityManagement {
if s.metadata.DisableEntityManagement {
return nil
}
@ -773,7 +775,7 @@ func (s *snsSqs) Subscribe(ctx context.Context, req pubsub.SubscribeRequest, han
// this is the ID of the application, it is supplied via runtime as "consumerID".
var queueInfo *sqsQueueInfo
queueInfo, err = s.getOrCreateQueue(ctx, s.metadata.sqsQueueName)
queueInfo, err = s.getOrCreateQueue(ctx, s.metadata.SqsQueueName)
if err != nil {
wrappedErr := fmt.Errorf("error retrieving SQS queue: %w", err)
s.logger.Error(wrappedErr)
@ -795,8 +797,8 @@ func (s *snsSqs) Subscribe(ctx context.Context, req pubsub.SubscribeRequest, han
var deadLettersQueueInfo *sqsQueueInfo
var derr error
if len(s.metadata.sqsDeadLettersQueueName) > 0 {
deadLettersQueueInfo, derr = s.getOrCreateQueue(ctx, s.metadata.sqsDeadLettersQueueName)
if len(s.metadata.SqsDeadLettersQueueName) > 0 {
deadLettersQueueInfo, derr = s.getOrCreateQueue(ctx, s.metadata.SqsDeadLettersQueueName)
if derr != nil {
wrappedErr := fmt.Errorf("error retrieving SQS dead-letter queue: %w", err)
s.logger.Error(wrappedErr)
@ -898,7 +900,7 @@ func (s *snsSqs) Publish(ctx context.Context, req *pubsub.PublishRequest) error
Message: aws.String(message),
TopicArn: aws.String(topicArn),
}
if s.metadata.fifo {
if s.metadata.Fifo {
snsPublishInput.MessageGroupId = s.getMessageGroupID(req)
}
@ -927,3 +929,11 @@ func (s *snsSqs) Close() error {
func (s *snsSqs) Features() []pubsub.Feature {
return nil
}
// GetComponentMetadata returns the metadata of the component, derived via
// reflection from the snsSqsMetadata struct's pubsub-tagged fields.
func (s *snsSqs) GetComponentMetadata() map[string]string {
	info := map[string]string{}
	metadata.GetMetadataInfoFromStructType(reflect.TypeOf(snsSqsMetadata{}), &info, metadata.PubSubType)
	return info
}

View File

@ -65,19 +65,19 @@ func Test_getSnsSqsMetatdata_AllConfiguration(t *testing.T) {
r.NoError(err)
r.Equal("consumer", md.sqsQueueName)
r.Equal("consumer", md.SqsQueueName)
r.Equal("endpoint", md.Endpoint)
r.Equal(pubsub.Single, md.concurrencyMode)
r.Equal(pubsub.Single, md.ConcurrencyMode)
r.Equal("a", md.AccessKey)
r.Equal("s", md.SecretKey)
r.Equal("t", md.SessionToken)
r.Equal("r", md.Region)
r.Equal("q", md.sqsDeadLettersQueueName)
r.Equal(int64(2), md.messageVisibilityTimeout)
r.Equal(int64(3), md.messageRetryLimit)
r.Equal(int64(4), md.messageWaitTimeSeconds)
r.Equal(int64(5), md.messageMaxNumber)
r.Equal(int64(6), md.messageReceiveLimit)
r.Equal("q", md.SqsDeadLettersQueueName)
r.Equal(int64(2), md.MessageVisibilityTimeout)
r.Equal(int64(3), md.MessageRetryLimit)
r.Equal(int64(4), md.MessageWaitTimeSeconds)
r.Equal(int64(5), md.MessageMaxNumber)
r.Equal(int64(6), md.MessageReceiveLimit)
}
func Test_getSnsSqsMetatdata_defaults(t *testing.T) {
@ -98,20 +98,20 @@ func Test_getSnsSqsMetatdata_defaults(t *testing.T) {
r.NoError(err)
r.Equal("c", md.sqsQueueName)
r.Equal("c", md.SqsQueueName)
r.Equal("", md.Endpoint)
r.Equal("a", md.AccessKey)
r.Equal("s", md.SecretKey)
r.Equal("", md.SessionToken)
r.Equal("r", md.Region)
r.Equal(pubsub.Parallel, md.concurrencyMode)
r.Equal(int64(10), md.messageVisibilityTimeout)
r.Equal(int64(10), md.messageRetryLimit)
r.Equal(int64(2), md.messageWaitTimeSeconds)
r.Equal(int64(10), md.messageMaxNumber)
r.Equal(false, md.disableEntityManagement)
r.Equal(float64(5), md.assetsManagementTimeoutSeconds)
r.Equal(false, md.disableDeleteOnRetryLimit)
r.Equal(pubsub.Parallel, md.ConcurrencyMode)
r.Equal(int64(10), md.MessageVisibilityTimeout)
r.Equal(int64(10), md.MessageRetryLimit)
r.Equal(int64(2), md.MessageWaitTimeSeconds)
r.Equal(int64(10), md.MessageMaxNumber)
r.Equal(false, md.DisableEntityManagement)
r.Equal(float64(5), md.AssetsManagementTimeoutSeconds)
r.Equal(false, md.DisableDeleteOnRetryLimit)
}
func Test_getSnsSqsMetatdata_legacyaliases(t *testing.T) {
@ -132,15 +132,15 @@ func Test_getSnsSqsMetatdata_legacyaliases(t *testing.T) {
r.NoError(err)
r.Equal("consumer", md.sqsQueueName)
r.Equal("consumer", md.SqsQueueName)
r.Equal("", md.Endpoint)
r.Equal("acctId", md.AccessKey)
r.Equal("secret", md.SecretKey)
r.Equal("region", md.Region)
r.Equal(int64(10), md.messageVisibilityTimeout)
r.Equal(int64(10), md.messageRetryLimit)
r.Equal(int64(2), md.messageWaitTimeSeconds)
r.Equal(int64(10), md.messageMaxNumber)
r.Equal(int64(10), md.MessageVisibilityTimeout)
r.Equal(int64(10), md.MessageRetryLimit)
r.Equal(int64(2), md.MessageWaitTimeSeconds)
r.Equal(int64(10), md.MessageMaxNumber)
}
func testMetadataParsingShouldFail(t *testing.T, metadata pubsub.Metadata, l logger.Logger) {
@ -161,18 +161,6 @@ func Test_getSnsSqsMetatdata_invalidMetadataSetup(t *testing.T) {
t.Parallel()
fixtures := []testUnitFixture{
{
metadata: pubsub.Metadata{Base: metadata.Base{Properties: map[string]string{
"consumerID": "consumer",
"Endpoint": "endpoint",
"AccessKey": "acctId",
"SecretKey": "secret",
"awsToken": "token",
"Region": "region",
"fifo": "none bool",
}}},
name: "fifo not set to boolean",
},
{
metadata: pubsub.Metadata{Base: metadata.Base{Properties: map[string]string{
"consumerID": "consumer",
@ -221,7 +209,7 @@ func Test_getSnsSqsMetatdata_invalidMetadataSetup(t *testing.T) {
"Region": "region",
"messageMaxNumber": "-100",
}}},
name: "illigal message max number (negative, too low)",
name: "illegal message max number (negative, too low)",
},
{
metadata: pubsub.Metadata{Base: metadata.Base{Properties: map[string]string{
@ -233,7 +221,7 @@ func Test_getSnsSqsMetatdata_invalidMetadataSetup(t *testing.T) {
"Region": "region",
"messageMaxNumber": "100",
}}},
name: "illigal message max number (too high)",
name: "illegal message max number (too high)",
},
{
metadata: pubsub.Metadata{Base: metadata.Base{Properties: map[string]string{
@ -271,20 +259,6 @@ func Test_getSnsSqsMetatdata_invalidMetadataSetup(t *testing.T) {
}}},
name: "invalid message retry limit",
},
// disableEntityManagement
{
metadata: pubsub.Metadata{Base: metadata.Base{Properties: map[string]string{
"consumerID": "consumer",
"Endpoint": "endpoint",
"AccessKey": "acctId",
"SecretKey": "secret",
"awsToken": "token",
"Region": "region",
"messageRetryLimit": "10",
"disableEntityManagement": "y",
}}},
name: "invalid message disableEntityManagement",
},
// invalid concurrencyMode
{
metadata: pubsub.Metadata{Base: metadata.Base{Properties: map[string]string{
@ -311,24 +285,6 @@ func Test_getSnsSqsMetatdata_invalidMetadataSetup(t *testing.T) {
}
}
func Test_parseInt64(t *testing.T) {
t.Parallel()
r := require.New(t)
number, err := parseInt64("applesauce", "propertyName")
r.EqualError(err, "parsing propertyName failed with: strconv.Atoi: parsing \"applesauce\": invalid syntax")
r.Equal(int64(-1), number)
number, _ = parseInt64("1000", "")
r.Equal(int64(1000), number)
number, _ = parseInt64("-1000", "")
r.Equal(int64(-1000), number)
// Expecting that this function doesn't panic.
_, err = parseInt64("999999999999999999999999999999999999999999999999999999999999999999999999999", "")
r.Error(err)
}
func Test_replaceNameToAWSSanitizedName(t *testing.T) {
t.Parallel()
r := require.New(t)
@ -483,7 +439,7 @@ func Test_buildARN_DefaultPartition(t *testing.T) {
"region": "r",
}}})
r.NoError(err)
md.accountID = "123456789012"
md.AccountID = "123456789012"
ps.metadata = md
arn := ps.buildARN("sns", "myTopic")
@ -506,7 +462,7 @@ func Test_buildARN_StandardPartition(t *testing.T) {
"region": "us-west-2",
}}})
r.NoError(err)
md.accountID = "123456789012"
md.AccountID = "123456789012"
ps.metadata = md
arn := ps.buildARN("sns", "myTopic")
@ -529,7 +485,7 @@ func Test_buildARN_NonStandardPartition(t *testing.T) {
"region": "cn-northwest-1",
}}})
r.NoError(err)
md.accountID = "123456789012"
md.AccountID = "123456789012"
ps.metadata = md
arn := ps.buildARN("sns", "myTopic")

View File

@ -16,13 +16,14 @@ package eventhubs
import (
"context"
"errors"
"reflect"
"strconv"
"github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs"
impl "github.com/dapr/components-contrib/internal/component/azure/eventhubs"
"github.com/dapr/components-contrib/internal/utils"
contribMetadata "github.com/dapr/components-contrib/metadata"
"github.com/dapr/components-contrib/metadata"
"github.com/dapr/components-contrib/pubsub"
"github.com/dapr/kit/logger"
"github.com/dapr/kit/ptr"
@ -82,7 +83,7 @@ func (aeh *AzureEventHubs) BulkPublish(ctx context.Context, req *pubsub.BulkPubl
// Batch options
batchOpts := &azeventhubs.EventDataBatchOptions{}
if val := req.Metadata[contribMetadata.MaxBulkPubBytesKey]; val != "" {
if val := req.Metadata[metadata.MaxBulkPubBytesKey]; val != "" {
var maxBytes uint64
maxBytes, err = strconv.ParseUint(val, 10, 63)
if err == nil && maxBytes > 0 {
@ -144,3 +145,11 @@ func (aeh *AzureEventHubs) Subscribe(ctx context.Context, req pubsub.SubscribeRe
func (aeh *AzureEventHubs) Close() (err error) {
return aeh.AzureEventHubs.Close()
}
// GetComponentMetadata returns the metadata of the component, derived via
// reflection from the shared Event Hubs metadata struct.
func (aeh *AzureEventHubs) GetComponentMetadata() map[string]string {
	info := map[string]string{}
	metadata.GetMetadataInfoFromStructType(reflect.TypeOf(impl.AzureEventHubsMetadata{}), &info, metadata.PubSubType)
	return info
}

View File

@ -17,12 +17,14 @@ import (
"context"
"errors"
"fmt"
"reflect"
"sync"
"sync/atomic"
"time"
impl "github.com/dapr/components-contrib/internal/component/azure/servicebus"
"github.com/dapr/components-contrib/internal/utils"
"github.com/dapr/components-contrib/metadata"
"github.com/dapr/components-contrib/pubsub"
"github.com/dapr/kit/logger"
)
@ -222,3 +224,12 @@ func (a *azureServiceBus) Features() []pubsub.Feature {
pubsub.FeatureMessageTTL,
}
}
// GetComponentMetadata returns the metadata of the component, derived via
// reflection from the shared Service Bus metadata struct. The consumerID
// entry is stripped because it does not apply to queues.
func (a *azureServiceBus) GetComponentMetadata() map[string]string {
	info := map[string]string{}
	metadata.GetMetadataInfoFromStructType(reflect.TypeOf(impl.Metadata{}), &info, metadata.PubSubType)
	delete(info, "consumerID") // does not apply to queues
	return info
}

View File

@ -17,12 +17,14 @@ import (
"context"
"errors"
"fmt"
"reflect"
"sync"
"sync/atomic"
"time"
impl "github.com/dapr/components-contrib/internal/component/azure/servicebus"
"github.com/dapr/components-contrib/internal/utils"
"github.com/dapr/components-contrib/metadata"
"github.com/dapr/components-contrib/pubsub"
"github.com/dapr/kit/logger"
)
@ -312,3 +314,11 @@ func (a *azureServiceBus) connectAndReceiveWithSessions(ctx context.Context, req
}()
}
}
// GetComponentMetadata returns the metadata of the component, derived via
// reflection from the shared Service Bus metadata struct.
func (a *azureServiceBus) GetComponentMetadata() map[string]string {
	info := map[string]string{}
	metadata.GetMetadataInfoFromStructType(reflect.TypeOf(impl.Metadata{}), &info, metadata.PubSubType)
	return info
}

View File

@ -15,20 +15,20 @@ package pubsub
// GCPPubSubMetaData pubsub metadata.
type metadata struct {
consumerID string
Type string
IdentityProjectID string
ProjectID string
PrivateKeyID string
PrivateKey string
ClientEmail string
ClientID string
AuthURI string
TokenURI string
AuthProviderCertURL string
ClientCertURL string
DisableEntityManagement bool
EnableMessageOrdering bool
MaxReconnectionAttempts int
ConnectionRecoveryInSec int
ConsumerID string `mapstructure:"consumerID"`
Type string `mapstructure:"type"`
IdentityProjectID string `mapstructure:"identityProjectID"`
ProjectID string `mapstructure:"projectID"`
PrivateKeyID string `mapstructure:"privateKeyID"`
PrivateKey string `mapstructure:"privateKey"`
ClientEmail string `mapstructure:"clientEmail"`
ClientID string `mapstructure:"clientID"`
AuthURI string `mapstructure:"authURI"`
TokenURI string `mapstructure:"tokenURI"`
AuthProviderCertURL string `mapstructure:"authProviderX509CertUrl"`
ClientCertURL string `mapstructure:"clientX509CertUrl"`
DisableEntityManagement bool `mapstructure:"disableEntityManagement"`
EnableMessageOrdering bool `mapstructure:"enableMessageOrdering"`
MaxReconnectionAttempts int `mapstructure:"maxReconnectionAttempts"`
ConnectionRecoveryInSec int `mapstructure:"connectionRecoveryInSec"`
}

View File

@ -18,7 +18,7 @@ import (
"encoding/json"
"errors"
"fmt"
"strconv"
"reflect"
"sync"
"sync/atomic"
"time"
@ -28,6 +28,7 @@ import (
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
contribMetadata "github.com/dapr/components-contrib/metadata"
"github.com/dapr/components-contrib/pubsub"
"github.com/dapr/kit/logger"
)
@ -96,88 +97,19 @@ func createMetadata(pubSubMetadata pubsub.Metadata) (*metadata, error) {
result := metadata{
DisableEntityManagement: false,
Type: "service_account",
MaxReconnectionAttempts: defaultMaxReconnectionAttempts,
ConnectionRecoveryInSec: defaultConnectionRecoveryInSec,
}
if val, found := pubSubMetadata.Properties[metadataTypeKey]; found && val != "" {
result.Type = val
err := contribMetadata.DecodeMetadata(pubSubMetadata.Properties, &result)
if err != nil {
return nil, err
}
if val, found := pubSubMetadata.Properties[metadataConsumerIDKey]; found && val != "" {
result.consumerID = val
}
if val, found := pubSubMetadata.Properties[metadataIdentityProjectIDKey]; found && val != "" {
result.IdentityProjectID = val
}
if val, found := pubSubMetadata.Properties[metadataProjectIDKey]; found && val != "" {
result.ProjectID = val
} else {
if result.ProjectID == "" {
return &result, fmt.Errorf("%s missing attribute %s", errorMessagePrefix, metadataProjectIDKey)
}
if val, found := pubSubMetadata.Properties[metadataPrivateKeyIDKey]; found && val != "" {
result.PrivateKeyID = val
}
if val, found := pubSubMetadata.Properties[metadataClientEmailKey]; found && val != "" {
result.ClientEmail = val
}
if val, found := pubSubMetadata.Properties[metadataClientIDKey]; found && val != "" {
result.ClientID = val
}
if val, found := pubSubMetadata.Properties[metadataAuthURIKey]; found && val != "" {
result.AuthURI = val
}
if val, found := pubSubMetadata.Properties[metadataTokenURIKey]; found && val != "" {
result.TokenURI = val
}
if val, found := pubSubMetadata.Properties[metadataAuthProviderX509CertURLKey]; found && val != "" {
result.AuthProviderCertURL = val
}
if val, found := pubSubMetadata.Properties[metadataClientX509CertURLKey]; found && val != "" {
result.ClientCertURL = val
}
if val, found := pubSubMetadata.Properties[metadataPrivateKeyKey]; found && val != "" {
result.PrivateKey = val
}
if val, found := pubSubMetadata.Properties[metadataDisableEntityManagementKey]; found && val != "" {
if boolVal, err := strconv.ParseBool(val); err == nil {
result.DisableEntityManagement = boolVal
}
}
if val, found := pubSubMetadata.Properties[metadataEnableMessageOrderingKey]; found && val != "" {
if boolVal, err := strconv.ParseBool(val); err == nil {
result.EnableMessageOrdering = boolVal
}
}
result.MaxReconnectionAttempts = defaultMaxReconnectionAttempts
if val, ok := pubSubMetadata.Properties[metadataMaxReconnectionAttemptsKey]; ok && val != "" {
var err error
result.MaxReconnectionAttempts, err = strconv.Atoi(val)
if err != nil {
return &result, fmt.Errorf("%s invalid maxReconnectionAttempts %s, %s", errorMessagePrefix, val, err)
}
}
result.ConnectionRecoveryInSec = defaultConnectionRecoveryInSec
if val, ok := pubSubMetadata.Properties[metadataConnectionRecoveryInSecKey]; ok && val != "" {
var err error
result.ConnectionRecoveryInSec, err = strconv.Atoi(val)
if err != nil {
return &result, fmt.Errorf("%s invalid connectionRecoveryInSec %s, %s", errorMessagePrefix, val, err)
}
}
return &result, nil
}
@ -269,14 +201,14 @@ func (g *GCPPubSub) Subscribe(parentCtx context.Context, req pubsub.SubscribeReq
return fmt.Errorf("%s could not get valid topic %s, %s", errorMessagePrefix, req.Topic, topicErr)
}
subError := g.ensureSubscription(parentCtx, g.metadata.consumerID, req.Topic)
subError := g.ensureSubscription(parentCtx, g.metadata.ConsumerID, req.Topic)
if subError != nil {
return fmt.Errorf("%s could not get valid subscription %s, %s", errorMessagePrefix, g.metadata.consumerID, subError)
return fmt.Errorf("%s could not get valid subscription %s, %s", errorMessagePrefix, g.metadata.ConsumerID, subError)
}
}
topic := g.getTopic(req.Topic)
sub := g.getSubscription(g.metadata.consumerID + "-" + req.Topic)
sub := g.getSubscription(g.metadata.ConsumerID + "-" + req.Topic)
subscribeCtx, cancel := context.WithCancel(parentCtx)
g.wg.Add(2)
@ -435,3 +367,11 @@ func (g *GCPPubSub) Close() error {
func (g *GCPPubSub) Features() []pubsub.Feature {
return nil
}
// GetComponentMetadata returns the names and descriptions of the metadata
// properties this component accepts, derived via reflection from the
// component's metadata struct.
func (g *GCPPubSub) GetComponentMetadata() map[string]string {
	info := map[string]string{}
	contribMetadata.GetMetadataInfoFromStructType(reflect.TypeOf(metadata{}), &info, contribMetadata.PubSubType)
	return info
}

View File

@ -83,7 +83,6 @@ func TestInit(t *testing.T) {
m.Properties = map[string]string{
"projectId": "superproject",
}
m.Properties[metadataMaxReconnectionAttemptsKey] = ""
pubSubMetadata, err := createMetadata(m)
@ -101,7 +100,7 @@ func TestInit(t *testing.T) {
_, err := createMetadata(m)
assert.Error(t, err)
assertValidErrorMessage(t, err)
assert.ErrorContains(t, err, "maxReconnectionAttempts")
})
t.Run("missing optional connectionRecoveryInSec", func(t *testing.T) {
@ -109,7 +108,6 @@ func TestInit(t *testing.T) {
m.Properties = map[string]string{
"projectId": "superproject",
}
m.Properties[metadataConnectionRecoveryInSecKey] = ""
pubSubMetadata, err := createMetadata(m)
@ -127,10 +125,6 @@ func TestInit(t *testing.T) {
_, err := createMetadata(m)
assert.Error(t, err)
assertValidErrorMessage(t, err)
assert.ErrorContains(t, err, "connectionRecoveryInSec")
})
}
// assertValidErrorMessage verifies that err's message carries the
// component's standard error prefix (errorMessagePrefix).
// Note: err must be non-nil; err.Error() would panic otherwise.
func assertValidErrorMessage(t *testing.T, err error) {
	assert.Contains(t, err.Error(), errorMessagePrefix)
}

View File

@ -110,3 +110,8 @@ func (a *bus) Subscribe(ctx context.Context, req pubsub.SubscribeRequest, handle
return nil
}
// GetComponentMetadata returns the metadata of the component.
// The in-memory bus accepts no configurable metadata, so the result is
// always an empty (non-nil) map.
func (a *bus) GetComponentMetadata() map[string]string {
	return make(map[string]string)
}

View File

@ -16,12 +16,14 @@ package jetstream
import (
"context"
"errors"
"reflect"
"sync"
"sync/atomic"
"github.com/nats-io/nats.go"
"github.com/nats-io/nkeys"
mdutils "github.com/dapr/components-contrib/metadata"
"github.com/dapr/components-contrib/pubsub"
"github.com/dapr/kit/logger"
"github.com/dapr/kit/retry"
@ -55,37 +57,37 @@ func (js *jetstreamPubSub) Init(_ context.Context, metadata pubsub.Metadata) err
}
var opts []nats.Option
opts = append(opts, nats.Name(js.meta.name))
opts = append(opts, nats.Name(js.meta.Name))
// Set nats.UserJWT options when jwt and seed key is provided.
if js.meta.jwt != "" && js.meta.seedKey != "" {
if js.meta.Jwt != "" && js.meta.SeedKey != "" {
opts = append(opts, nats.UserJWT(func() (string, error) {
return js.meta.jwt, nil
return js.meta.Jwt, nil
}, func(nonce []byte) ([]byte, error) {
return sigHandler(js.meta.seedKey, nonce)
return sigHandler(js.meta.SeedKey, nonce)
}))
} else if js.meta.tlsClientCert != "" && js.meta.tlsClientKey != "" {
} else if js.meta.TLSClientCert != "" && js.meta.TLSClientKey != "" {
js.l.Debug("Configure nats for tls client authentication")
opts = append(opts, nats.ClientCert(js.meta.tlsClientCert, js.meta.tlsClientKey))
} else if js.meta.token != "" {
opts = append(opts, nats.ClientCert(js.meta.TLSClientCert, js.meta.TLSClientKey))
} else if js.meta.Token != "" {
js.l.Debug("Configure nats for token authentication")
opts = append(opts, nats.Token(js.meta.token))
opts = append(opts, nats.Token(js.meta.Token))
}
js.nc, err = nats.Connect(js.meta.natsURL, opts...)
js.nc, err = nats.Connect(js.meta.NatsURL, opts...)
if err != nil {
return err
}
js.l.Debugf("Connected to nats at %s", js.meta.natsURL)
js.l.Debugf("Connected to nats at %s", js.meta.NatsURL)
jsOpts := []nats.JSOpt{}
if js.meta.domain != "" {
jsOpts = append(jsOpts, nats.Domain(js.meta.domain))
if js.meta.Domain != "" {
jsOpts = append(jsOpts, nats.Domain(js.meta.Domain))
}
if js.meta.apiPrefix != "" {
jsOpts = append(jsOpts, nats.APIPrefix(js.meta.apiPrefix))
if js.meta.APIPrefix != "" {
jsOpts = append(jsOpts, nats.APIPrefix(js.meta.APIPrefix))
}
js.jsc, err = js.nc.JetStream(jsOpts...)
@ -148,46 +150,46 @@ func (js *jetstreamPubSub) Subscribe(ctx context.Context, req pubsub.SubscribeRe
consumerConfig.DeliverSubject = nats.NewInbox()
if v := js.meta.durableName; v != "" {
if v := js.meta.DurableName; v != "" {
consumerConfig.Durable = v
}
if v := js.meta.startTime; !v.IsZero() {
if v := js.meta.internalStartTime; !v.IsZero() {
consumerConfig.OptStartTime = &v
}
if v := js.meta.startSequence; v > 0 {
if v := js.meta.StartSequence; v > 0 {
consumerConfig.OptStartSeq = v
}
consumerConfig.DeliverPolicy = js.meta.deliverPolicy
if js.meta.flowControl {
consumerConfig.DeliverPolicy = js.meta.internalDeliverPolicy
if js.meta.FlowControl {
consumerConfig.FlowControl = true
}
if js.meta.ackWait != 0 {
consumerConfig.AckWait = js.meta.ackWait
if js.meta.AckWait != 0 {
consumerConfig.AckWait = js.meta.AckWait
}
if js.meta.maxDeliver != 0 {
consumerConfig.MaxDeliver = js.meta.maxDeliver
if js.meta.MaxDeliver != 0 {
consumerConfig.MaxDeliver = js.meta.MaxDeliver
}
if len(js.meta.backOff) != 0 {
consumerConfig.BackOff = js.meta.backOff
if len(js.meta.BackOff) != 0 {
consumerConfig.BackOff = js.meta.BackOff
}
if js.meta.maxAckPending != 0 {
consumerConfig.MaxAckPending = js.meta.maxAckPending
if js.meta.MaxAckPending != 0 {
consumerConfig.MaxAckPending = js.meta.MaxAckPending
}
if js.meta.replicas != 0 {
consumerConfig.Replicas = js.meta.replicas
if js.meta.Replicas != 0 {
consumerConfig.Replicas = js.meta.Replicas
}
if js.meta.memoryStorage {
if js.meta.MemoryStorage {
consumerConfig.MemoryStorage = true
}
if js.meta.rateLimit != 0 {
consumerConfig.RateLimit = js.meta.rateLimit
if js.meta.RateLimit != 0 {
consumerConfig.RateLimit = js.meta.RateLimit
}
if js.meta.heartbeat != 0 {
consumerConfig.Heartbeat = js.meta.heartbeat
if js.meta.Heartbeat != 0 {
consumerConfig.Heartbeat = js.meta.Heartbeat
}
consumerConfig.AckPolicy = js.meta.ackPolicy
consumerConfig.AckPolicy = js.meta.internalAckPolicy
consumerConfig.FilterSubject = req.Topic
natsHandler := func(m *nats.Msg) {
@ -211,7 +213,7 @@ func (js *jetstreamPubSub) Subscribe(ctx context.Context, req pubsub.SubscribeRe
if err != nil {
js.l.Errorf("Error processing JetStream message %s/%d: %v", m.Subject, jsm.Sequence, err)
if js.meta.ackPolicy == nats.AckExplicitPolicy || js.meta.ackPolicy == nats.AckAllPolicy {
if js.meta.internalAckPolicy == nats.AckExplicitPolicy || js.meta.internalAckPolicy == nats.AckAllPolicy {
nakErr := m.Nak()
if nakErr != nil {
js.l.Errorf("Error while sending NAK for JetStream message %s/%d: %v", m.Subject, jsm.Sequence, nakErr)
@ -221,7 +223,7 @@ func (js *jetstreamPubSub) Subscribe(ctx context.Context, req pubsub.SubscribeRe
return
}
if js.meta.ackPolicy == nats.AckExplicitPolicy || js.meta.ackPolicy == nats.AckAllPolicy {
if js.meta.internalAckPolicy == nats.AckExplicitPolicy || js.meta.internalAckPolicy == nats.AckAllPolicy {
err = m.Ack()
if err != nil {
js.l.Errorf("Error while sending ACK for JetStream message %s/%d: %v", m.Subject, jsm.Sequence, err)
@ -230,7 +232,7 @@ func (js *jetstreamPubSub) Subscribe(ctx context.Context, req pubsub.SubscribeRe
}
var err error
streamName := js.meta.streamName
streamName := js.meta.StreamName
if streamName == "" {
streamName, err = js.jsc.StreamNameBySubject(req.Topic)
if err != nil {
@ -244,9 +246,9 @@ func (js *jetstreamPubSub) Subscribe(ctx context.Context, req pubsub.SubscribeRe
return err
}
if queue := js.meta.queueGroupName; queue != "" {
if queue := js.meta.QueueGroupName; queue != "" {
js.l.Debugf("nats: subscribed to subject %s with queue group %s",
req.Topic, js.meta.queueGroupName)
req.Topic, js.meta.QueueGroupName)
subscription, err = js.jsc.QueueSubscribe(req.Topic, queue, natsHandler, nats.Bind(streamName, consumerInfo.Name))
} else {
js.l.Debugf("nats: subscribed to subject %s", req.Topic)
@ -292,3 +294,11 @@ func sigHandler(seedKey string, nonce []byte) ([]byte, error) {
sig, _ := kp.Sign(nonce)
return sig, nil
}
// GetComponentMetadata returns the metadata of the component, built by
// reflecting over the fields of the metadata struct.
func (js *jetstreamPubSub) GetComponentMetadata() map[string]string {
	info := map[string]string{}
	mdutils.GetMetadataInfoFromStructType(reflect.TypeOf(metadata{}), &info, mdutils.PubSubType)
	return info
}

View File

@ -15,166 +15,105 @@ package jetstream
import (
"fmt"
"strconv"
"strings"
"time"
"github.com/nats-io/nats.go"
contribMetadata "github.com/dapr/components-contrib/metadata"
"github.com/dapr/components-contrib/pubsub"
)
type metadata struct {
natsURL string
NatsURL string `mapstructure:"natsURL"`
jwt string
seedKey string
token string
Jwt string `mapstructure:"jwt"`
SeedKey string `mapstructure:"seedKey"`
Token string `mapstructure:"token"`
tlsClientCert string
tlsClientKey string
TLSClientCert string `mapstructure:"tls_client_cert"`
TLSClientKey string `mapstructure:"tls_client_key"`
name string
streamName string
durableName string
queueGroupName string
startSequence uint64
startTime time.Time
flowControl bool
ackWait time.Duration
maxDeliver int
backOff []time.Duration
maxAckPending int
replicas int
memoryStorage bool
rateLimit uint64
heartbeat time.Duration
deliverPolicy nats.DeliverPolicy
ackPolicy nats.AckPolicy
domain string
apiPrefix string
Name string `mapstructure:"name"`
StreamName string `mapstructure:"streamName"`
DurableName string `mapstructure:"durableName"`
QueueGroupName string `mapstructure:"queueGroupName"`
StartSequence uint64 `mapstructure:"startSequence"`
StartTime *uint64 `mapstructure:"startTime"`
internalStartTime time.Time `mapstructure:"-"`
FlowControl bool `mapstructure:"flowControl"`
AckWait time.Duration `mapstructure:"ackWait"`
MaxDeliver int `mapstructure:"maxDeliver"`
BackOff []time.Duration `mapstructure:"backOff"`
MaxAckPending int `mapstructure:"maxAckPending"`
Replicas int `mapstructure:"replicas"`
MemoryStorage bool `mapstructure:"memoryStorage"`
RateLimit uint64 `mapstructure:"rateLimit"`
Heartbeat time.Duration `mapstructure:"heartbeat"`
DeliverPolicy string `mapstructure:"deliverPolicy"`
internalDeliverPolicy nats.DeliverPolicy `mapstructure:"-"`
AckPolicy string `mapstructure:"ackPolicy"`
internalAckPolicy nats.AckPolicy `mapstructure:"-"`
Domain string `mapstructure:"domain"`
APIPrefix string `mapstructure:"apiPrefix"`
}
func parseMetadata(psm pubsub.Metadata) (metadata, error) {
var m metadata
if v, ok := psm.Properties["natsURL"]; ok && v != "" {
m.natsURL = v
} else {
contribMetadata.DecodeMetadata(psm.Properties, &m)
if m.NatsURL == "" {
return metadata{}, fmt.Errorf("missing nats URL")
}
m.token = psm.Properties["token"]
m.jwt = psm.Properties["jwt"]
m.seedKey = psm.Properties["seedKey"]
if m.jwt != "" && m.seedKey == "" {
if m.Jwt != "" && m.SeedKey == "" {
return metadata{}, fmt.Errorf("missing seed key")
}
if m.jwt == "" && m.seedKey != "" {
if m.Jwt == "" && m.SeedKey != "" {
return metadata{}, fmt.Errorf("missing jwt")
}
m.tlsClientCert = psm.Properties["tls_client_cert"]
m.tlsClientKey = psm.Properties["tls_client_key"]
if m.tlsClientCert != "" && m.tlsClientKey == "" {
if m.TLSClientCert != "" && m.TLSClientKey == "" {
return metadata{}, fmt.Errorf("missing tls client key")
}
if m.tlsClientCert == "" && m.tlsClientKey != "" {
if m.TLSClientCert == "" && m.TLSClientKey != "" {
return metadata{}, fmt.Errorf("missing tls client cert")
}
if m.name = psm.Properties["name"]; m.name == "" {
m.name = "dapr.io - pubsub.jetstream"
if m.Name == "" {
m.Name = "dapr.io - pubsub.jetstream"
}
m.durableName = psm.Properties["durableName"]
m.queueGroupName = psm.Properties["queueGroupName"]
if v, err := strconv.ParseUint(psm.Properties["startSequence"], 10, 64); err == nil {
m.startSequence = v
if m.StartTime != nil {
m.internalStartTime = time.Unix(int64(*m.StartTime), 0)
}
if v, err := strconv.ParseInt(psm.Properties["startTime"], 10, 64); err == nil {
m.startTime = time.Unix(v, 0)
}
if v, err := strconv.ParseBool(psm.Properties["flowControl"]); err == nil {
m.flowControl = v
}
if v, err := time.ParseDuration(psm.Properties["ackWait"]); err == nil {
m.ackWait = v
}
if v, err := strconv.Atoi(psm.Properties["maxDeliver"]); err == nil {
m.maxDeliver = v
}
backOffSlice := strings.Split(psm.Properties["backOff"], ",")
var backOff []time.Duration
for _, item := range backOffSlice {
trimmed := strings.TrimSpace(item)
if duration, err := time.ParseDuration(trimmed); err == nil {
backOff = append(backOff, duration)
}
}
m.backOff = backOff
if v, err := strconv.Atoi(psm.Properties["maxAckPending"]); err == nil {
m.maxAckPending = v
}
if v, err := strconv.Atoi(psm.Properties["replicas"]); err == nil {
m.replicas = v
}
if v, err := strconv.ParseBool(psm.Properties["memoryStorage"]); err == nil {
m.memoryStorage = v
}
if v, err := strconv.ParseUint(psm.Properties["rateLimit"], 10, 64); err == nil {
m.rateLimit = v
}
if v, err := time.ParseDuration(psm.Properties["heartbeat"]); err == nil {
m.heartbeat = v
}
if domain := psm.Properties["domain"]; domain != "" {
m.domain = domain
}
if apiPrefix := psm.Properties["apiPrefix"]; apiPrefix != "" {
m.apiPrefix = apiPrefix
}
deliverPolicy := psm.Properties["deliverPolicy"]
switch deliverPolicy {
switch m.DeliverPolicy {
case "all", "":
m.deliverPolicy = nats.DeliverAllPolicy
m.internalDeliverPolicy = nats.DeliverAllPolicy
case "last":
m.deliverPolicy = nats.DeliverLastPolicy
m.internalDeliverPolicy = nats.DeliverLastPolicy
case "new":
m.deliverPolicy = nats.DeliverNewPolicy
m.internalDeliverPolicy = nats.DeliverNewPolicy
case "sequence":
m.deliverPolicy = nats.DeliverByStartSequencePolicy
m.internalDeliverPolicy = nats.DeliverByStartSequencePolicy
case "time":
m.deliverPolicy = nats.DeliverByStartTimePolicy
m.internalDeliverPolicy = nats.DeliverByStartTimePolicy
default:
return metadata{}, fmt.Errorf("deliver policy %s is not one of: all, last, new, sequence, time", deliverPolicy)
return metadata{}, fmt.Errorf("deliver policy %s is not one of: all, last, new, sequence, time", m.DeliverPolicy)
}
m.streamName = psm.Properties["streamName"]
switch psm.Properties["ackPolicy"] {
switch m.AckPolicy {
case "explicit":
m.ackPolicy = nats.AckExplicitPolicy
m.internalAckPolicy = nats.AckExplicitPolicy
case "all":
m.ackPolicy = nats.AckAllPolicy
m.internalAckPolicy = nats.AckAllPolicy
case "none":
m.ackPolicy = nats.AckNonePolicy
m.internalAckPolicy = nats.AckNonePolicy
default:
m.ackPolicy = nats.AckExplicitPolicy
m.internalAckPolicy = nats.AckExplicitPolicy
}
return m, nil

View File

@ -22,6 +22,7 @@ import (
mdata "github.com/dapr/components-contrib/metadata"
"github.com/dapr/components-contrib/pubsub"
"github.com/dapr/kit/ptr"
)
func TestParseMetadata(t *testing.T) {
@ -54,24 +55,25 @@ func TestParseMetadata(t *testing.T) {
},
}},
want: metadata{
natsURL: "nats://localhost:4222",
name: "myName",
durableName: "myDurable",
queueGroupName: "myQueue",
startSequence: 1,
startTime: time.Unix(1629328511, 0),
flowControl: true,
ackWait: 2 * time.Second,
maxDeliver: 10,
backOff: []time.Duration{time.Millisecond * 500, time.Second * 2, time.Second * 10},
maxAckPending: 5000,
replicas: 3,
memoryStorage: true,
rateLimit: 20000,
heartbeat: time.Second * 1,
deliverPolicy: nats.DeliverAllPolicy,
ackPolicy: nats.AckExplicitPolicy,
domain: "hub",
NatsURL: "nats://localhost:4222",
Name: "myName",
DurableName: "myDurable",
QueueGroupName: "myQueue",
StartSequence: 1,
StartTime: ptr.Of(uint64(1629328511)),
internalStartTime: time.Unix(1629328511, 0),
FlowControl: true,
AckWait: 2 * time.Second,
MaxDeliver: 10,
BackOff: []time.Duration{time.Millisecond * 500, time.Second * 2, time.Second * 10},
MaxAckPending: 5000,
Replicas: 3,
MemoryStorage: true,
RateLimit: 20000,
Heartbeat: time.Second * 1,
internalDeliverPolicy: nats.DeliverAllPolicy,
internalAckPolicy: nats.AckExplicitPolicy,
Domain: "hub",
},
expectErr: false,
},
@ -101,25 +103,28 @@ func TestParseMetadata(t *testing.T) {
},
}},
want: metadata{
natsURL: "nats://localhost:4222",
name: "myName",
durableName: "myDurable",
queueGroupName: "myQueue",
startSequence: 5,
startTime: time.Unix(1629328511, 0),
flowControl: true,
ackWait: 2 * time.Second,
maxDeliver: 10,
backOff: []time.Duration{time.Millisecond * 500, time.Second * 2, time.Second * 10},
maxAckPending: 5000,
replicas: 3,
memoryStorage: true,
rateLimit: 20000,
heartbeat: time.Second * 1,
token: "myToken",
deliverPolicy: nats.DeliverByStartSequencePolicy,
ackPolicy: nats.AckAllPolicy,
apiPrefix: "HUB",
NatsURL: "nats://localhost:4222",
Name: "myName",
DurableName: "myDurable",
QueueGroupName: "myQueue",
StartSequence: 5,
StartTime: ptr.Of(uint64(1629328511)),
internalStartTime: time.Unix(1629328511, 0),
FlowControl: true,
AckWait: 2 * time.Second,
MaxDeliver: 10,
BackOff: []time.Duration{time.Millisecond * 500, time.Second * 2, time.Second * 10},
MaxAckPending: 5000,
Replicas: 3,
MemoryStorage: true,
RateLimit: 20000,
Heartbeat: time.Second * 1,
Token: "myToken",
DeliverPolicy: "sequence",
AckPolicy: "all",
internalDeliverPolicy: nats.DeliverByStartSequencePolicy,
internalAckPolicy: nats.AckAllPolicy,
APIPrefix: "HUB",
},
expectErr: false,
},

View File

@ -16,6 +16,7 @@ package kafka
import (
"context"
"errors"
"reflect"
"sync"
"sync/atomic"
@ -23,6 +24,7 @@ import (
"github.com/dapr/components-contrib/internal/component/kafka"
"github.com/dapr/components-contrib/internal/utils"
"github.com/dapr/components-contrib/metadata"
"github.com/dapr/components-contrib/pubsub"
)
@ -173,3 +175,11 @@ func adaptBulkHandler(handler pubsub.BulkHandler) kafka.BulkEventHandler {
})
}
}
// GetComponentMetadata returns the metadata of the component, derived via
// reflection from the shared Kafka metadata struct.
func (p *PubSub) GetComponentMetadata() map[string]string {
	info := map[string]string{}
	metadata.GetMetadataInfoFromStructType(reflect.TypeOf(kafka.KafkaMetadata{}), &info, metadata.PubSubType)
	return info
}

View File

@ -3,16 +3,18 @@ package kubemq
import (
"context"
"fmt"
"reflect"
"time"
"github.com/google/uuid"
contribMetadata "github.com/dapr/components-contrib/metadata"
"github.com/dapr/components-contrib/pubsub"
"github.com/dapr/kit/logger"
)
type kubeMQ struct {
metadata *metadata
metadata *kubemqMetadata
logger logger.Logger
eventsClient *kubeMQEvents
eventStoreClient *kubeMQEventStore
@ -31,7 +33,7 @@ func (k *kubeMQ) Init(_ context.Context, metadata pubsub.Metadata) error {
return err
}
k.metadata = meta
if meta.isStore {
if meta.IsStore {
k.eventStoreClient = newKubeMQEventsStore(k.logger)
_ = k.eventStoreClient.Init(meta)
} else {
@ -46,7 +48,7 @@ func (k *kubeMQ) Features() []pubsub.Feature {
}
func (k *kubeMQ) Publish(_ context.Context, req *pubsub.PublishRequest) error {
if k.metadata.isStore {
if k.metadata.IsStore {
return k.eventStoreClient.Publish(req)
} else {
return k.eventsClient.Publish(req)
@ -54,7 +56,7 @@ func (k *kubeMQ) Publish(_ context.Context, req *pubsub.PublishRequest) error {
}
func (k *kubeMQ) Subscribe(ctx context.Context, req pubsub.SubscribeRequest, handler pubsub.Handler) error {
if k.metadata.isStore {
if k.metadata.IsStore {
return k.eventStoreClient.Subscribe(ctx, req, handler)
} else {
return k.eventsClient.Subscribe(ctx, req, handler)
@ -62,7 +64,7 @@ func (k *kubeMQ) Subscribe(ctx context.Context, req pubsub.SubscribeRequest, han
}
func (k *kubeMQ) Close() error {
if k.metadata.isStore {
if k.metadata.IsStore {
return k.eventStoreClient.Close()
} else {
return k.eventsClient.Close()
@ -76,3 +78,11 @@ func getRandomID() string {
}
return randomUUID.String()
}
// GetComponentMetadata returns the metadata of the component, derived via
// reflection from the kubemqMetadata struct.
func (k *kubeMQ) GetComponentMetadata() map[string]string {
	// Pass the struct value, not a pointer: reflect.TypeOf(&kubemqMetadata{})
	// would yield the pointer type *kubemqMetadata rather than the struct
	// type, and every other pubsub component passes the value form to
	// GetMetadataInfoFromStructType.
	metadataStruct := kubemqMetadata{}
	metadataInfo := map[string]string{}
	contribMetadata.GetMetadataInfoFromStructType(reflect.TypeOf(metadataStruct), &metadataInfo, contribMetadata.PubSubType)
	return metadataInfo
}

View File

@ -20,7 +20,7 @@ type kubemqEventsClient interface {
type kubeMQEvents struct {
lock sync.RWMutex
client kubemqEventsClient
metadata *metadata
metadata *kubemqMetadata
logger logger.Logger
publishFunc func(event *kubemq.Event) error
resultChan chan error
@ -54,16 +54,16 @@ func (k *kubeMQEvents) init() error {
k.lock.Lock()
defer k.lock.Unlock()
k.ctx, k.ctxCancel = context.WithCancel(context.Background())
clientID := k.metadata.clientID
clientID := k.metadata.ClientID
if clientID == "" {
clientID = getRandomID()
}
client, err := kubemq.NewEventsClient(k.ctx,
kubemq.WithAddress(k.metadata.host, k.metadata.port),
kubemq.WithAddress(k.metadata.internalHost, k.metadata.internalPort),
kubemq.WithClientId(clientID),
kubemq.WithTransportType(kubemq.TransportTypeGRPC),
kubemq.WithCheckConnection(true),
kubemq.WithAuthToken(k.metadata.authToken),
kubemq.WithAuthToken(k.metadata.AuthToken),
kubemq.WithAutoReconnect(true),
kubemq.WithReconnectInterval(time.Second))
if err != nil {
@ -80,7 +80,7 @@ func (k *kubeMQEvents) init() error {
return nil
}
func (k *kubeMQEvents) Init(meta *metadata) error {
func (k *kubeMQEvents) Init(meta *kubemqMetadata) error {
k.metadata = meta
_ = k.init()
return nil
@ -107,7 +107,7 @@ func (k *kubeMQEvents) Publish(req *pubsub.PublishRequest) error {
Channel: req.Topic,
Metadata: "",
Body: req.Data,
ClientId: k.metadata.clientID,
ClientId: k.metadata.ClientID,
Tags: map[string]string{},
}
if err := k.publishFunc(event); err != nil {
@ -125,14 +125,14 @@ func (k *kubeMQEvents) Subscribe(ctx context.Context, req pubsub.SubscribeReques
if err := k.init(); err != nil {
return err
}
clientID := k.metadata.clientID
clientID := k.metadata.ClientID
if clientID == "" {
clientID = getRandomID()
}
k.logger.Debugf("kubemq pub/sub: subscribing to %s", req.Topic)
err := k.client.Subscribe(ctx, &kubemq.EventsSubscription{
Channel: req.Topic,
Group: k.metadata.group,
Group: k.metadata.Group,
ClientId: clientID,
}, func(event *kubemq.Event, err error) {
if err != nil {
@ -149,7 +149,7 @@ func (k *kubeMQEvents) Subscribe(ctx context.Context, req pubsub.SubscribeReques
if err := handler(k.ctx, msg); err != nil {
k.logger.Errorf("kubemq events pub/sub error: error handling message from topic '%s', %s", req.Topic, err.Error())
if k.metadata.disableReDelivery {
if k.metadata.DisableReDelivery {
return
}
if err := k.Publish(&pubsub.PublishRequest{

View File

@ -122,13 +122,13 @@ func Test_kubeMQEvents_Publish(t *testing.T) {
setResultError(tt.resultError).
setPublishError(tt.publishErr)
k.isInitialized = true
k.metadata = &metadata{
host: "",
port: 0,
clientID: "some-client-id",
authToken: "",
group: "",
isStore: false,
k.metadata = &kubemqMetadata{
internalHost: "",
internalPort: 0,
ClientID: "some-client-id",
AuthToken: "",
Group: "",
IsStore: false,
}
if tt.timeout > 0 {
k.waitForResultTimeout = tt.timeout - 1*time.Second
@ -185,13 +185,13 @@ func Test_kubeMQEvents_Subscribe(t *testing.T) {
k.client = newKubemqEventsMock().
setSubscribeError(tt.subscribeError)
k.isInitialized = true
k.metadata = &metadata{
host: "",
port: 0,
clientID: "some-client-id",
authToken: "",
group: "",
isStore: false,
k.metadata = &kubemqMetadata{
internalHost: "",
internalPort: 0,
ClientID: "some-client-id",
AuthToken: "",
Group: "",
IsStore: false,
}
err := k.Subscribe(k.ctx, pubsub.SubscribeRequest{Topic: "some-topic"}, tt.subscribeHandler)
if tt.wantErr {

View File

@ -22,7 +22,7 @@ type kubemqEventsStoreClient interface {
type kubeMQEventStore struct {
lock sync.RWMutex
client kubemqEventsStoreClient
metadata *metadata
metadata *kubemqMetadata
logger logger.Logger
publishFunc func(msg *kubemq.EventStore) error
resultChan chan *kubemq.EventStoreResult
@ -56,16 +56,16 @@ func (k *kubeMQEventStore) init() error {
k.lock.Lock()
defer k.lock.Unlock()
k.ctx, k.ctxCancel = context.WithCancel(context.Background())
clientID := k.metadata.clientID
clientID := k.metadata.ClientID
if clientID == "" {
clientID = getRandomID()
}
client, err := kubemq.NewEventsStoreClient(k.ctx,
kubemq.WithAddress(k.metadata.host, k.metadata.port),
kubemq.WithAddress(k.metadata.internalHost, k.metadata.internalPort),
kubemq.WithClientId(clientID),
kubemq.WithTransportType(kubemq.TransportTypeGRPC),
kubemq.WithCheckConnection(true),
kubemq.WithAuthToken(k.metadata.authToken),
kubemq.WithAuthToken(k.metadata.AuthToken),
kubemq.WithAutoReconnect(true),
kubemq.WithReconnectInterval(time.Second))
if err != nil {
@ -82,7 +82,7 @@ func (k *kubeMQEventStore) init() error {
return nil
}
func (k *kubeMQEventStore) Init(meta *metadata) error {
func (k *kubeMQEventStore) Init(meta *kubemqMetadata) error {
k.metadata = meta
_ = k.init()
return nil
@ -112,7 +112,7 @@ func (k *kubeMQEventStore) Publish(req *pubsub.PublishRequest) error {
Channel: req.Topic,
Metadata: "",
Body: req.Data,
ClientId: k.metadata.clientID,
ClientId: k.metadata.ClientID,
Tags: map[string]string{},
}
if err := k.publishFunc(event); err != nil {
@ -138,7 +138,7 @@ func (k *kubeMQEventStore) Subscribe(ctx context.Context, req pubsub.SubscribeRe
if err := k.init(); err != nil {
return err
}
clientID := k.metadata.clientID
clientID := k.metadata.ClientID
if clientID == "" {
clientID = getRandomID()
}
@ -146,7 +146,7 @@ func (k *kubeMQEventStore) Subscribe(ctx context.Context, req pubsub.SubscribeRe
k.logger.Debugf("kubemq pub/sub: subscribing to %s", req.Topic)
err := k.client.Subscribe(ctx, &kubemq.EventsStoreSubscription{
Channel: req.Topic,
Group: k.metadata.group,
Group: k.metadata.Group,
ClientId: clientID,
SubscriptionType: kubemq.StartFromNewEvents(),
}, func(event *kubemq.EventStoreReceive, err error) {
@ -166,7 +166,7 @@ func (k *kubeMQEventStore) Subscribe(ctx context.Context, req pubsub.SubscribeRe
if err := handler(ctx, msg); err != nil {
k.logger.Errorf("kubemq pub/sub error: error handling message from topic '%s', %s, resending...", req.Topic, err.Error())
if k.metadata.disableReDelivery {
if k.metadata.DisableReDelivery {
return
}
if err := k.Publish(&pubsub.PublishRequest{

View File

@ -145,13 +145,13 @@ func Test_kubeMQEventsStore_Publish(t *testing.T) {
setResultError(tt.resultError).
setPublishError(tt.publishErr)
k.isInitialized = true
k.metadata = &metadata{
host: "",
port: 0,
clientID: "some-client-id",
authToken: "",
group: "",
isStore: true,
k.metadata = &kubemqMetadata{
internalHost: "",
internalPort: 0,
ClientID: "some-client-id",
AuthToken: "",
Group: "",
IsStore: true,
}
if tt.timeout > 0 {
k.waitForResultTimeout = tt.timeout - 1*time.Second
@ -208,13 +208,13 @@ func Test_kubeMQkubeMQEventsStore_Subscribe(t *testing.T) {
k.client = newKubemqEventsStoreMock().
setSubscribeError(tt.subscribeError)
k.isInitialized = true
k.metadata = &metadata{
host: "",
port: 0,
clientID: "some-client-id",
authToken: "",
group: "",
isStore: true,
k.metadata = &kubemqMetadata{
internalHost: "",
internalPort: 0,
ClientID: "some-client-id",
AuthToken: "",
Group: "",
IsStore: true,
}
err := k.Subscribe(k.ctx, pubsub.SubscribeRequest{Topic: "some-topic"}, tt.subscribeHandler)
if tt.wantErr {

View File

@ -106,7 +106,7 @@ func Test_kubeMQ_Init(t *testing.T) {
func Test_kubeMQ_Close(t *testing.T) {
type fields struct {
metadata *metadata
metadata *kubemqMetadata
logger logger.Logger
eventsClient *kubeMQEvents
eventStoreClient *kubeMQEventStore
@ -119,8 +119,8 @@ func Test_kubeMQ_Close(t *testing.T) {
{
name: "close events client",
fields: fields{
metadata: &metadata{
isStore: false,
metadata: &kubemqMetadata{
IsStore: false,
},
eventsClient: getMockEventsClient(),
eventStoreClient: nil,
@ -130,8 +130,8 @@ func Test_kubeMQ_Close(t *testing.T) {
{
name: "close events store client",
fields: fields{
metadata: &metadata{
isStore: true,
metadata: &kubemqMetadata{
IsStore: true,
},
eventsClient: nil,
eventStoreClient: getMockEventsStoreClient(),

View File

@ -5,17 +5,19 @@ import (
"strconv"
"strings"
"github.com/dapr/components-contrib/metadata"
"github.com/dapr/components-contrib/pubsub"
)
type metadata struct {
host string
port int
clientID string
authToken string
group string
isStore bool
disableReDelivery bool
type kubemqMetadata struct {
Address string `mapstructure:"address"`
internalHost string `mapstructure:"-"`
internalPort int `mapstructure:"-"`
ClientID string `mapstructure:"clientID"`
AuthToken string `mapstructure:"authToken"`
Group string `mapstructure:"group"`
IsStore bool `mapstructure:"store"`
DisableReDelivery bool `mapstructure:"disableReDelivery"`
}
func parseAddress(address string) (string, int, error) {
@ -38,43 +40,24 @@ func parseAddress(address string) (string, int, error) {
}
// createMetadata creates a new instance from the pubsub metadata
func createMetadata(pubSubMetadata pubsub.Metadata) (*metadata, error) {
result := &metadata{}
if val, found := pubSubMetadata.Properties["address"]; found && val != "" {
func createMetadata(pubSubMetadata pubsub.Metadata) (*kubemqMetadata, error) {
result := &kubemqMetadata{
IsStore: true,
}
err := metadata.DecodeMetadata(pubSubMetadata.Properties, result)
if err != nil {
return nil, err
}
if result.Address != "" {
var err error
result.host, result.port, err = parseAddress(val)
result.internalHost, result.internalPort, err = parseAddress(result.Address)
if err != nil {
return nil, err
}
} else {
return nil, fmt.Errorf("invalid kubeMQ address, address is empty")
}
if val, found := pubSubMetadata.Properties["clientID"]; found && val != "" {
result.clientID = val
}
if val, found := pubSubMetadata.Properties["authToken"]; found && val != "" {
result.authToken = val
}
if val, found := pubSubMetadata.Properties["group"]; found && val != "" {
result.group = val
}
result.isStore = true
if val, found := pubSubMetadata.Properties["store"]; found && val != "" {
switch val {
case "false":
result.isStore = false
case "true":
result.isStore = true
default:
return nil, fmt.Errorf("invalid kubeMQ store value, store can be true or false")
}
}
if val, found := pubSubMetadata.Properties["disableReDelivery"]; found && val != "" {
if val == "true" {
result.disableReDelivery = true
}
}
return result, nil
}

View File

@ -13,7 +13,7 @@ func Test_createMetadata(t *testing.T) {
tests := []struct {
name string
meta pubsub.Metadata
want *metadata
want *kubemqMetadata
wantErr bool
}{
{
@ -32,14 +32,15 @@ func Test_createMetadata(t *testing.T) {
},
},
},
want: &metadata{
host: "localhost",
port: 50000,
clientID: "clientID",
authToken: "authToken",
group: "group",
isStore: true,
disableReDelivery: true,
want: &kubemqMetadata{
Address: "localhost:50000",
internalHost: "localhost",
internalPort: 50000,
ClientID: "clientID",
AuthToken: "authToken",
Group: "group",
IsStore: true,
DisableReDelivery: true,
},
wantErr: false,
},
@ -55,13 +56,14 @@ func Test_createMetadata(t *testing.T) {
},
},
},
want: &metadata{
host: "localhost",
port: 50000,
clientID: "clientID",
authToken: "authToken",
group: "",
isStore: false,
want: &kubemqMetadata{
Address: "localhost:50000",
internalHost: "localhost",
internalPort: 50000,
ClientID: "clientID",
AuthToken: "authToken",
Group: "",
IsStore: false,
},
wantErr: false,
},
@ -78,13 +80,14 @@ func Test_createMetadata(t *testing.T) {
},
},
},
want: &metadata{
host: "localhost",
port: 50000,
clientID: "clientID",
authToken: "",
group: "group",
isStore: true,
want: &kubemqMetadata{
Address: "localhost:50000",
internalHost: "localhost",
internalPort: 50000,
ClientID: "clientID",
AuthToken: "",
Group: "group",
IsStore: true,
},
wantErr: false,
},
@ -140,20 +143,6 @@ func Test_createMetadata(t *testing.T) {
want: nil,
wantErr: true,
},
{
name: "create invalid metadata with bad store info",
meta: pubsub.Metadata{
Base: mdata.Base{
Properties: map[string]string{
"address": "localhost:50000",
"clientID": "clientID",
"store": "bad",
},
},
},
want: nil,
wantErr: true,
},
}
for _, tt := range tests {

View File

@ -16,21 +16,20 @@ package mqtt
import (
"errors"
"fmt"
"strconv"
"time"
"github.com/dapr/components-contrib/internal/utils"
"github.com/dapr/components-contrib/metadata"
"github.com/dapr/components-contrib/pubsub"
"github.com/dapr/kit/logger"
)
type metadata struct {
pubsub.TLSProperties
url string
consumerID string
qos byte
retain bool
cleanSession bool
type mqttMetadata struct {
pubsub.TLSProperties `mapstructure:",squash"`
URL string `mapstructure:"url"`
ConsumerID string `mapstructure:"consumerID"`
Qos byte `mapstructure:"qos"`
Retain bool `mapstructure:"retain"`
CleanSession bool `mapstructure:"cleanSession"`
}
const (
@ -48,44 +47,32 @@ const (
defaultCleanSession = false
)
func parseMQTTMetaData(md pubsub.Metadata, log logger.Logger) (*metadata, error) {
m := metadata{}
func parseMQTTMetaData(md pubsub.Metadata, log logger.Logger) (*mqttMetadata, error) {
m := mqttMetadata{
Qos: defaultQOS,
CleanSession: defaultCleanSession,
}
err := metadata.DecodeMetadata(md.Properties, &m)
if err != nil {
return &m, fmt.Errorf("mqtt pubsub error: %w", err)
}
// required configuration settings
if val, ok := md.Properties[mqttURL]; ok && val != "" {
m.url = val
} else {
if m.URL == "" {
return &m, errors.New("missing url")
}
// optional configuration settings
m.qos = defaultQOS
if val, ok := md.Properties[mqttQOS]; ok && val != "" {
qosInt, err := strconv.Atoi(val)
if err != nil || qosInt < 0 || qosInt > 7 {
return &m, fmt.Errorf("invalid qos %s: %w", val, err)
}
m.qos = byte(qosInt)
}
m.retain = defaultRetain
if val, ok := md.Properties[mqttRetain]; ok && val != "" {
m.retain = utils.IsTruthy(val)
if m.Qos > 7 { // bytes cannot be less than 0
return &m, fmt.Errorf("invalid qos %d: %w", m.Qos, err)
}
// Note: the runtime sets the default value to the Dapr app ID if empty
if val, ok := md.Properties[mqttConsumerID]; ok && val != "" {
m.consumerID = val
} else {
if m.ConsumerID == "" {
return &m, errors.New("missing consumerID")
}
m.cleanSession = defaultCleanSession
if val, ok := md.Properties[mqttCleanSession]; ok && val != "" {
m.cleanSession = utils.IsTruthy(val)
}
var err error
m.TLSProperties, err = pubsub.TLS(md.Properties)
if err != nil {
return &m, fmt.Errorf("invalid TLS configuration: %w", err)

View File

@ -18,6 +18,7 @@ import (
"errors"
"fmt"
"net/url"
"reflect"
"regexp"
"strconv"
"strings"
@ -29,6 +30,7 @@ import (
"golang.org/x/exp/maps"
"github.com/dapr/components-contrib/internal/utils"
"github.com/dapr/components-contrib/metadata"
"github.com/dapr/components-contrib/pubsub"
"github.com/dapr/kit/logger"
)
@ -41,7 +43,7 @@ const (
// mqttPubSub type allows sending and receiving data to/from MQTT broker.
type mqttPubSub struct {
conn mqtt.Client
metadata *metadata
metadata *mqttMetadata
logger logger.Logger
topics map[string]mqttPubSubSubscription
subscribingLock sync.RWMutex
@ -99,7 +101,7 @@ func (m *mqttPubSub) Publish(ctx context.Context, req *pubsub.PublishRequest) (e
// m.logger.Debugf("mqtt publishing topic %s with data: %v", req.Topic, req.Data)
m.logger.Debugf("mqtt publishing topic %s", req.Topic)
retain := m.metadata.retain
retain := m.metadata.Retain
if val, ok := req.Metadata[mqttRetain]; ok && val != "" {
retain, err = strconv.ParseBool(val)
if err != nil {
@ -107,7 +109,7 @@ func (m *mqttPubSub) Publish(ctx context.Context, req *pubsub.PublishRequest) (e
}
}
token := m.conn.Publish(req.Topic, m.metadata.qos, retain, req.Data)
token := m.conn.Publish(req.Topic, m.metadata.Qos, retain, req.Data)
ctx, cancel := context.WithTimeout(ctx, defaultWait)
defer cancel()
select {
@ -146,7 +148,7 @@ func (m *mqttPubSub) Subscribe(ctx context.Context, req pubsub.SubscribeRequest,
// Add the topic then start the subscription
m.addTopic(topic, handler)
token := m.conn.Subscribe(topic, m.metadata.qos, m.onMessage(ctx))
token := m.conn.Subscribe(topic, m.metadata.Qos, m.onMessage(ctx))
var err error
select {
case <-token.Done():
@ -163,7 +165,7 @@ func (m *mqttPubSub) Subscribe(ctx context.Context, req pubsub.SubscribeRequest,
return fmt.Errorf("mqtt error from subscribe: %v", err)
}
m.logger.Infof("MQTT is subscribed to topic %s (qos: %d)", topic, m.metadata.qos)
m.logger.Infof("MQTT is subscribed to topic %s (qos: %d)", topic, m.metadata.Qos)
// Listen for context cancelation to remove the subscription
m.wg.Add(1)
@ -185,7 +187,7 @@ func (m *mqttPubSub) Subscribe(ctx context.Context, req pubsub.SubscribeRequest,
// We will call Unsubscribe only if cleanSession is true or if "unsubscribeOnClose" in the request metadata is true
// Otherwise, calling this will make the broker lose the position of our subscription, which is not what we want if we are going to reconnect later
if !m.metadata.cleanSession && !unsubscribeOnClose {
if !m.metadata.CleanSession && !unsubscribeOnClose {
return
}
@ -258,7 +260,7 @@ func (m *mqttPubSub) handlerForTopic(topic string) pubsub.Handler {
}
func (m *mqttPubSub) doConnect(ctx context.Context, clientID string) (mqtt.Client, error) {
uri, err := url.Parse(m.metadata.url)
uri, err := url.Parse(m.metadata.URL)
if err != nil {
return nil, err
}
@ -287,7 +289,7 @@ func (m *mqttPubSub) connect(ctx context.Context) error {
ctx, cancel := context.WithTimeout(ctx, defaultWait)
defer cancel()
conn, err := m.doConnect(ctx, m.metadata.consumerID)
conn, err := m.doConnect(ctx, m.metadata.ConsumerID)
if err != nil {
return err
}
@ -299,7 +301,7 @@ func (m *mqttPubSub) connect(ctx context.Context) error {
func (m *mqttPubSub) createClientOptions(uri *url.URL, clientID string) *mqtt.ClientOptions {
opts := mqtt.NewClientOptions().
SetClientID(clientID).
SetCleanSession(m.metadata.cleanSession).
SetCleanSession(m.metadata.CleanSession).
// If OrderMatters is true (default), handlers must not block, which is not an option for us
SetOrderMatters(false).
// Disable automatic ACKs as we need to do it manually
@ -331,7 +333,7 @@ func (m *mqttPubSub) createClientOptions(uri *url.URL, clientID string) *mqtt.Cl
// Create the list of topics to subscribe to
subscribeTopics := make(map[string]byte, len(m.topics))
for k := range m.topics {
subscribeTopics[k] = m.metadata.qos
subscribeTopics[k] = m.metadata.Qos
}
// Note that this is a bit unusual for a pubsub component as we're using a background context for the handler.
@ -491,3 +493,11 @@ func buildRegexForTopic(topicName string) string {
return regexStr
}
// GetComponentMetadata returns the metadata of the component.
func (m *mqttPubSub) GetComponentMetadata() map[string]string {
	info := map[string]string{}
	metadata.GetMetadataInfoFromStructType(reflect.TypeOf(mqttMetadata{}), &info, metadata.PubSubType)
	return info
}

View File

@ -227,10 +227,10 @@ func TestParseMetadata(t *testing.T) {
// assert
assert.NoError(t, err)
assert.Equal(t, fakeProperties[mqttURL], m.url)
assert.Equal(t, byte(1), m.qos)
assert.Equal(t, true, m.retain)
assert.Equal(t, false, m.cleanSession)
assert.Equal(t, fakeProperties[mqttURL], m.URL)
assert.Equal(t, byte(1), m.Qos)
assert.Equal(t, true, m.Retain)
assert.Equal(t, false, m.CleanSession)
})
t.Run("missing consumerID", func(t *testing.T) {
@ -255,7 +255,7 @@ func TestParseMetadata(t *testing.T) {
// assert
assert.ErrorContains(t, err, "missing url")
assert.Equal(t, fakeProperties[mqttURL], m.url)
assert.Equal(t, fakeProperties[mqttURL], m.URL)
})
t.Run("qos and retain is not given", func(t *testing.T) {
@ -264,16 +264,16 @@ func TestParseMetadata(t *testing.T) {
fakeMetaData := pubsub.Metadata{
Base: mdata.Base{Properties: fakeProperties},
}
fakeMetaData.Properties[mqttQOS] = ""
fakeMetaData.Properties[mqttRetain] = ""
delete(fakeMetaData.Properties, mqttQOS)
delete(fakeMetaData.Properties, mqttRetain)
m, err := parseMQTTMetaData(fakeMetaData, log)
// assert
require.NoError(t, err)
assert.Equal(t, fakeProperties[mqttURL], m.url)
assert.Equal(t, byte(1), m.qos)
assert.Equal(t, false, m.retain)
assert.Equal(t, fakeProperties[mqttURL], m.URL)
assert.Equal(t, byte(1), m.Qos)
assert.Equal(t, false, m.Retain)
})
t.Run("invalid ca certificate", func(t *testing.T) {
@ -624,7 +624,7 @@ func Test_buildRegexForTopic(t *testing.T) {
func Test_mqttPubSub_Publish(t *testing.T) {
type fields struct {
logger logger.Logger
metadata *metadata
metadata *mqttMetadata
ctx context.Context
}
type args struct {
@ -642,8 +642,8 @@ func Test_mqttPubSub_Publish(t *testing.T) {
fields: fields{
logger: logger.NewLogger("mqtt-test"),
ctx: context.Background(),
metadata: &metadata{
retain: true,
metadata: &mqttMetadata{
Retain: true,
},
},
args: args{
@ -668,8 +668,8 @@ func Test_mqttPubSub_Publish(t *testing.T) {
fields: fields{
logger: logger.NewLogger("mqtt-test"),
ctx: context.Background(),
metadata: &metadata{
retain: true,
metadata: &mqttMetadata{
Retain: true,
},
},
args: args{

View File

@ -19,20 +19,20 @@ import (
"github.com/dapr/components-contrib/pubsub"
)
type metadata struct {
natsURL string
natsStreamingClusterID string
subscriptionType string
natsQueueGroupName string
durableSubscriptionName string
startAtSequence uint64
startWithLastReceived string
deliverNew string
deliverAll string
startAtTimeDelta time.Duration
startAtTime string
startAtTimeFormat string
ackWaitTime time.Duration
maxInFlight uint64
concurrencyMode pubsub.ConcurrencyMode
type natsMetadata struct {
NatsURL string
NatsStreamingClusterID string
SubscriptionType string
NatsQueueGroupName string `mapstructure:"consumerId"`
DurableSubscriptionName string
StartAtSequence *uint64
StartWithLastReceived string
DeliverNew string
DeliverAll string
StartAtTimeDelta time.Duration
StartAtTime string
StartAtTimeFormat string
AckWaitTime time.Duration
MaxInFlight *uint64
ConcurrencyMode pubsub.ConcurrencyMode
}

View File

@ -21,7 +21,7 @@ import (
"errors"
"fmt"
"math/rand"
"strconv"
"reflect"
"sync"
"sync/atomic"
"time"
@ -30,6 +30,7 @@ import (
stan "github.com/nats-io/stan.go"
"github.com/nats-io/stan.go/pb"
"github.com/dapr/components-contrib/metadata"
"github.com/dapr/components-contrib/pubsub"
"github.com/dapr/kit/logger"
"github.com/dapr/kit/retry"
@ -70,7 +71,7 @@ const (
)
type natsStreamingPubSub struct {
metadata metadata
metadata natsMetadata
natStreamingConn stan.Conn
logger logger.Logger
@ -87,109 +88,93 @@ func NewNATSStreamingPubSub(logger logger.Logger) pubsub.PubSub {
return &natsStreamingPubSub{logger: logger, closeCh: make(chan struct{})}
}
func parseNATSStreamingMetadata(meta pubsub.Metadata) (metadata, error) {
m := metadata{}
if val, ok := meta.Properties[natsURL]; ok && val != "" {
m.natsURL = val
} else {
func parseNATSStreamingMetadata(meta pubsub.Metadata) (natsMetadata, error) {
m := natsMetadata{}
var err error
if err = metadata.DecodeMetadata(meta.Properties, &m); err != nil {
return m, err
}
if m.NatsURL == "" {
return m, errors.New("nats-streaming error: missing nats URL")
}
if val, ok := meta.Properties[natsStreamingClusterID]; ok && val != "" {
m.natsStreamingClusterID = val
} else {
if m.NatsStreamingClusterID == "" {
return m, errors.New("nats-streaming error: missing nats streaming cluster ID")
}
if val, ok := meta.Properties[subscriptionType]; ok {
if val == subscriptionTypeTopic || val == subscriptionTypeQueueGroup {
m.subscriptionType = val
} else {
return m, errors.New("nats-streaming error: valid value for subscriptionType is topic or queue")
}
switch m.SubscriptionType {
case subscriptionTypeTopic, subscriptionTypeQueueGroup, "":
// valid values
default:
return m, errors.New("nats-streaming error: valid value for subscriptionType is topic or queue")
}
if val, ok := meta.Properties[consumerID]; ok && val != "" {
m.natsQueueGroupName = val
} else {
if m.NatsQueueGroupName == "" {
return m, errors.New("nats-streaming error: missing queue group name")
}
if val, ok := meta.Properties[durableSubscriptionName]; ok && val != "" {
m.durableSubscriptionName = val
}
if val, ok := meta.Properties[ackWaitTime]; ok && val != "" {
dur, err := time.ParseDuration(meta.Properties[ackWaitTime])
if err != nil {
return m, fmt.Errorf("nats-streaming error %s ", err)
}
m.ackWaitTime = dur
}
if val, ok := meta.Properties[maxInFlight]; ok && val != "" {
max, err := strconv.ParseUint(meta.Properties[maxInFlight], 10, 64)
if err != nil {
return m, fmt.Errorf("nats-streaming error in parsemetadata for maxInFlight: %s ", err)
}
if max < 1 {
return m, errors.New("nats-streaming error: maxInFlight should be equal to or more than 1")
}
m.maxInFlight = max
if m.MaxInFlight != nil && *m.MaxInFlight < 1 {
return m, errors.New("nats-streaming error: maxInFlight should be equal to or more than 1")
}
//nolint:nestif
// subscription options - only one can be used
if val, ok := meta.Properties[startAtSequence]; ok && val != "" {
// nats streaming accepts a uint64 as sequence
seq, err := strconv.ParseUint(meta.Properties[startAtSequence], 10, 64)
if err != nil {
return m, fmt.Errorf("nats-streaming error %s ", err)
// helper function to reset mutually exclusive options
clearValues := func(m *natsMetadata, indexToKeep int) {
if indexToKeep != 0 {
m.StartAtSequence = nil
}
if seq < 1 {
return m, errors.New("nats-streaming error: startAtSequence should be equal to or more than 1")
if indexToKeep != 1 {
m.StartWithLastReceived = ""
}
m.startAtSequence = seq
} else if val, ok := meta.Properties[startWithLastReceived]; ok {
// only valid value is true
if val == startWithLastReceivedTrue {
m.startWithLastReceived = val
} else {
return m, errors.New("nats-streaming error: valid value for startWithLastReceived is true")
if indexToKeep != 2 {
m.DeliverAll = ""
}
} else if val, ok := meta.Properties[deliverAll]; ok {
// only valid value is true
if val == deliverAllTrue {
m.deliverAll = val
} else {
return m, errors.New("nats-streaming error: valid value for deliverAll is true")
if indexToKeep != 3 {
m.DeliverNew = ""
}
} else if val, ok := meta.Properties[deliverNew]; ok {
// only valid value is true
if val == deliverNewTrue {
m.deliverNew = val
} else {
return m, errors.New("nats-streaming error: valid value for deliverNew is true")
if indexToKeep != 4 {
m.StartAtTime = ""
}
} else if val, ok := meta.Properties[startAtTimeDelta]; ok && val != "" {
dur, err := time.ParseDuration(meta.Properties[startAtTimeDelta])
if err != nil {
return m, fmt.Errorf("nats-streaming error %s ", err)
}
m.startAtTimeDelta = dur
} else if val, ok := meta.Properties[startAtTime]; ok && val != "" {
m.startAtTime = val
if val, ok := meta.Properties[startAtTimeFormat]; ok && val != "" {
m.startAtTimeFormat = val
} else {
return m, errors.New("nats-streaming error: missing value for startAtTimeFormat")
if indexToKeep != 4 {
m.StartAtTimeFormat = ""
}
}
c, err := pubsub.Concurrency(meta.Properties)
switch {
case m.StartAtSequence != nil:
if *m.StartAtSequence < 1 {
return m, errors.New("nats-streaming error: startAtSequence should be equal to or more than 1")
}
clearValues(&m, 0)
case m.StartWithLastReceived != "":
if m.StartWithLastReceived != startWithLastReceivedTrue {
return m, errors.New("nats-streaming error: valid value for startWithLastReceived is true")
}
clearValues(&m, 1)
case m.DeliverAll != "":
if m.DeliverAll != deliverAllTrue {
return m, errors.New("nats-streaming error: valid value for deliverAll is true")
}
clearValues(&m, 2)
case m.DeliverNew != "":
if m.DeliverNew != deliverNewTrue {
return m, errors.New("nats-streaming error: valid value for deliverNew is true")
}
clearValues(&m, 3)
case m.StartAtTime != "":
if m.StartAtTimeFormat == "" {
return m, errors.New("nats-streaming error: missing value for startAtTimeFormat")
}
clearValues(&m, 4)
}
m.ConcurrencyMode, err = pubsub.Concurrency(meta.Properties)
if err != nil {
return m, fmt.Errorf("nats-streaming error: can't parse %s: %s", pubsub.ConcurrencyKey, err)
}
m.concurrencyMode = c
return m, nil
}
@ -201,15 +186,15 @@ func (n *natsStreamingPubSub) Init(_ context.Context, metadata pubsub.Metadata)
n.metadata = m
clientID := genRandomString(20)
opts := []nats.Option{nats.Name(clientID)}
natsConn, err := nats.Connect(m.natsURL, opts...)
natsConn, err := nats.Connect(m.NatsURL, opts...)
if err != nil {
return fmt.Errorf("nats-streaming: error connecting to nats server at %s: %s", m.natsURL, err)
return fmt.Errorf("nats-streaming: error connecting to nats server at %s: %s", m.NatsURL, err)
}
natStreamingConn, err := stan.Connect(m.natsStreamingClusterID, clientID, stan.NatsConn(natsConn))
natStreamingConn, err := stan.Connect(m.NatsStreamingClusterID, clientID, stan.NatsConn(natsConn))
if err != nil {
return fmt.Errorf("nats-streaming: error connecting to nats streaming server %s: %s", m.natsStreamingClusterID, err)
return fmt.Errorf("nats-streaming: error connecting to nats streaming server %s: %s", m.NatsStreamingClusterID, err)
}
n.logger.Debugf("connected to natsstreaming at %s", m.natsURL)
n.logger.Debugf("connected to natsstreaming at %s", m.NatsURL)
// Default retry configuration is used if no
// backOff properties are set.
@ -263,7 +248,7 @@ func (n *natsStreamingPubSub) Subscribe(ctx context.Context, req pubsub.Subscrib
}
}
switch n.metadata.concurrencyMode {
switch n.metadata.ConcurrencyMode {
case pubsub.Single:
f()
case pubsub.Parallel:
@ -276,10 +261,10 @@ func (n *natsStreamingPubSub) Subscribe(ctx context.Context, req pubsub.Subscrib
}
var subscription stan.Subscription
if n.metadata.subscriptionType == subscriptionTypeTopic {
if n.metadata.SubscriptionType == subscriptionTypeTopic {
subscription, err = n.natStreamingConn.Subscribe(req.Topic, natsMsgHandler, natStreamingsubscriptionOptions...)
} else if n.metadata.subscriptionType == subscriptionTypeQueueGroup {
subscription, err = n.natStreamingConn.QueueSubscribe(req.Topic, n.metadata.natsQueueGroupName, natsMsgHandler, natStreamingsubscriptionOptions...)
} else if n.metadata.SubscriptionType == subscriptionTypeQueueGroup {
subscription, err = n.natStreamingConn.QueueSubscribe(req.Topic, n.metadata.NatsQueueGroupName, natsMsgHandler, natStreamingsubscriptionOptions...)
}
if err != nil {
@ -299,10 +284,10 @@ func (n *natsStreamingPubSub) Subscribe(ctx context.Context, req pubsub.Subscrib
}
}()
if n.metadata.subscriptionType == subscriptionTypeTopic {
if n.metadata.SubscriptionType == subscriptionTypeTopic {
n.logger.Debugf("nats-streaming: subscribed to subject %s", req.Topic)
} else if n.metadata.subscriptionType == subscriptionTypeQueueGroup {
n.logger.Debugf("nats-streaming: subscribed to subject %s with queue group %s", req.Topic, n.metadata.natsQueueGroupName)
} else if n.metadata.SubscriptionType == subscriptionTypeQueueGroup {
n.logger.Debugf("nats-streaming: subscribed to subject %s with queue group %s", req.Topic, n.metadata.NatsQueueGroupName)
}
return nil
@ -311,24 +296,24 @@ func (n *natsStreamingPubSub) Subscribe(ctx context.Context, req pubsub.Subscrib
func (n *natsStreamingPubSub) subscriptionOptions() ([]stan.SubscriptionOption, error) {
var options []stan.SubscriptionOption
if n.metadata.durableSubscriptionName != "" {
options = append(options, stan.DurableName(n.metadata.durableSubscriptionName))
if n.metadata.DurableSubscriptionName != "" {
options = append(options, stan.DurableName(n.metadata.DurableSubscriptionName))
}
switch {
case n.metadata.deliverNew == deliverNewTrue:
case n.metadata.DeliverNew == deliverNewTrue:
options = append(options, stan.StartAt(pb.StartPosition_NewOnly)) //nolint:nosnakecase
case n.metadata.startAtSequence >= 1: // messages index start from 1, this is a valid check
options = append(options, stan.StartAtSequence(n.metadata.startAtSequence))
case n.metadata.startWithLastReceived == startWithLastReceivedTrue:
case n.metadata.StartAtSequence != nil && *n.metadata.StartAtSequence >= 1: // messages index start from 1, this is a valid check
options = append(options, stan.StartAtSequence(*n.metadata.StartAtSequence))
case n.metadata.StartWithLastReceived == startWithLastReceivedTrue:
options = append(options, stan.StartWithLastReceived())
case n.metadata.deliverAll == deliverAllTrue:
case n.metadata.DeliverAll == deliverAllTrue:
options = append(options, stan.DeliverAllAvailable())
case n.metadata.startAtTimeDelta > (1 * time.Nanosecond): // as long as its a valid time.Duration
options = append(options, stan.StartAtTimeDelta(n.metadata.startAtTimeDelta))
case n.metadata.startAtTime != "":
if n.metadata.startAtTimeFormat != "" {
startTime, err := time.Parse(n.metadata.startAtTimeFormat, n.metadata.startAtTime)
case n.metadata.StartAtTimeDelta > (1 * time.Nanosecond): // as long as its a valid time.Duration
options = append(options, stan.StartAtTimeDelta(n.metadata.StartAtTimeDelta))
case n.metadata.StartAtTime != "":
if n.metadata.StartAtTimeFormat != "" {
startTime, err := time.Parse(n.metadata.StartAtTimeFormat, n.metadata.StartAtTime)
if err != nil {
return nil, err
}
@ -340,11 +325,11 @@ func (n *natsStreamingPubSub) subscriptionOptions() ([]stan.SubscriptionOption,
options = append(options, stan.SetManualAckMode())
// check if set the ack options.
if n.metadata.ackWaitTime > (1 * time.Nanosecond) {
options = append(options, stan.AckWait(n.metadata.ackWaitTime))
if n.metadata.AckWaitTime > (1 * time.Nanosecond) {
options = append(options, stan.AckWait(n.metadata.AckWaitTime))
}
if n.metadata.maxInFlight >= 1 {
options = append(options, stan.MaxInflight(int(n.metadata.maxInFlight)))
if n.metadata.MaxInFlight != nil && *n.metadata.MaxInFlight >= 1 {
options = append(options, stan.MaxInflight(int(*n.metadata.MaxInFlight)))
}
return options, nil
@ -376,3 +361,11 @@ func (n *natsStreamingPubSub) Close() error {
func (n *natsStreamingPubSub) Features() []pubsub.Feature {
return nil
}
// GetComponentMetadata returns the metadata of the component.
func (n *natsStreamingPubSub) GetComponentMetadata() map[string]string {
	info := map[string]string{}
	metadata.GetMetadataInfoFromStructType(reflect.TypeOf(natsMetadata{}), &info, metadata.PubSubType)
	return info
}

View File

@ -22,6 +22,7 @@ import (
mdata "github.com/dapr/components-contrib/metadata"
"github.com/dapr/components-contrib/pubsub"
"github.com/dapr/kit/ptr"
)
func TestParseNATSStreamingForMetadataMandatoryOptionsMissing(t *testing.T) {
@ -236,18 +237,18 @@ func TestParseNATSStreamingMetadataForValidSubscriptionOptions(t *testing.T) {
assert.NoError(t, err)
assert.NotEmpty(t, m.natsURL)
assert.NotEmpty(t, m.natsStreamingClusterID)
assert.NotEmpty(t, m.subscriptionType)
assert.NotEmpty(t, m.natsQueueGroupName)
assert.NotEmpty(t, m.concurrencyMode)
assert.NotEmpty(t, m.NatsURL)
assert.NotEmpty(t, m.NatsStreamingClusterID)
assert.NotEmpty(t, m.SubscriptionType)
assert.NotEmpty(t, m.NatsQueueGroupName)
assert.NotEmpty(t, m.ConcurrencyMode)
assert.NotEmpty(t, _test.expectedMetadataValue)
assert.Equal(t, _test.properties[natsURL], m.natsURL)
assert.Equal(t, _test.properties[natsStreamingClusterID], m.natsStreamingClusterID)
assert.Equal(t, _test.properties[subscriptionType], m.subscriptionType)
assert.Equal(t, _test.properties[consumerID], m.natsQueueGroupName)
assert.Equal(t, _test.properties[pubsub.ConcurrencyKey], string(m.concurrencyMode))
assert.Equal(t, _test.properties[natsURL], m.NatsURL)
assert.Equal(t, _test.properties[natsStreamingClusterID], m.NatsStreamingClusterID)
assert.Equal(t, _test.properties[subscriptionType], m.SubscriptionType)
assert.Equal(t, _test.properties[consumerID], m.NatsQueueGroupName)
assert.Equal(t, _test.properties[pubsub.ConcurrencyKey], string(m.ConcurrencyMode))
assert.Equal(t, _test.properties[_test.expectedMetadataName], _test.expectedMetadataValue)
})
}
@ -266,12 +267,12 @@ func TestParseNATSStreamingMetadata(t *testing.T) {
m, err := parseNATSStreamingMetadata(fakeMetaData)
assert.NoError(t, err)
assert.NotEmpty(t, m.natsURL)
assert.NotEmpty(t, m.natsStreamingClusterID)
assert.NotEmpty(t, m.natsQueueGroupName)
assert.Equal(t, fakeProperties[natsURL], m.natsURL)
assert.Equal(t, fakeProperties[natsStreamingClusterID], m.natsStreamingClusterID)
assert.Equal(t, fakeProperties[consumerID], m.natsQueueGroupName)
assert.NotEmpty(t, m.NatsURL)
assert.NotEmpty(t, m.NatsStreamingClusterID)
assert.NotEmpty(t, m.NatsQueueGroupName)
assert.Equal(t, fakeProperties[natsURL], m.NatsURL)
assert.Equal(t, fakeProperties[natsStreamingClusterID], m.NatsStreamingClusterID)
assert.Equal(t, fakeProperties[consumerID], m.NatsQueueGroupName)
})
t.Run("subscription type missing", func(t *testing.T) {
@ -314,41 +315,41 @@ func TestParseNATSStreamingMetadata(t *testing.T) {
}
m, err := parseNATSStreamingMetadata(fakeMetaData)
assert.NoError(t, err)
assert.NotEmpty(t, m.natsURL)
assert.NotEmpty(t, m.natsStreamingClusterID)
assert.NotEmpty(t, m.subscriptionType)
assert.NotEmpty(t, m.natsQueueGroupName)
assert.NotEmpty(t, m.startAtSequence)
assert.NotEmpty(t, m.NatsURL)
assert.NotEmpty(t, m.NatsStreamingClusterID)
assert.NotEmpty(t, m.SubscriptionType)
assert.NotEmpty(t, m.NatsQueueGroupName)
assert.NotEmpty(t, m.StartAtSequence)
// startWithLastReceived ignored
assert.Empty(t, m.startWithLastReceived)
assert.Empty(t, m.StartWithLastReceived)
// deliverAll will be ignored
assert.Empty(t, m.deliverAll)
assert.Empty(t, m.DeliverAll)
assert.Equal(t, fakeProperties[natsURL], m.natsURL)
assert.Equal(t, fakeProperties[natsStreamingClusterID], m.natsStreamingClusterID)
assert.Equal(t, fakeProperties[subscriptionType], m.subscriptionType)
assert.Equal(t, fakeProperties[consumerID], m.natsQueueGroupName)
assert.Equal(t, fakeProperties[startAtSequence], strconv.FormatUint(m.startAtSequence, 10))
assert.Equal(t, fakeProperties[natsURL], m.NatsURL)
assert.Equal(t, fakeProperties[natsStreamingClusterID], m.NatsStreamingClusterID)
assert.Equal(t, fakeProperties[subscriptionType], m.SubscriptionType)
assert.Equal(t, fakeProperties[consumerID], m.NatsQueueGroupName)
assert.Equal(t, fakeProperties[startAtSequence], strconv.FormatUint(*m.StartAtSequence, 10))
})
}
func TestSubscriptionOptionsForValidOptions(t *testing.T) {
type test struct {
name string
m metadata
m natsMetadata
expectedNumberOfOptions int
}
tests := []test{
{"using durableSubscriptionName", metadata{durableSubscriptionName: "foobar"}, 2},
{"durableSubscriptionName is empty", metadata{durableSubscriptionName: ""}, 1},
{"using startAtSequence", metadata{startAtSequence: uint64(42)}, 2},
{"using startWithLastReceived", metadata{startWithLastReceived: startWithLastReceivedTrue}, 2},
{"using deliverAll", metadata{deliverAll: deliverAllTrue}, 2},
{"using startAtTimeDelta", metadata{startAtTimeDelta: 1 * time.Hour}, 2},
{"using startAtTime and startAtTimeFormat", metadata{startAtTime: "Feb 3, 2013 at 7:54pm (PST)", startAtTimeFormat: "Jan 2, 2006 at 3:04pm (MST)"}, 2},
{"using manual ack with ackWaitTime", metadata{ackWaitTime: 30 * time.Second}, 2},
{"using manual ack with maxInFlight", metadata{maxInFlight: uint64(42)}, 2},
{"using durableSubscriptionName", natsMetadata{DurableSubscriptionName: "foobar"}, 2},
{"durableSubscriptionName is empty", natsMetadata{DurableSubscriptionName: ""}, 1},
{"using startAtSequence", natsMetadata{StartAtSequence: ptr.Of(uint64(42))}, 2},
{"using startWithLastReceived", natsMetadata{StartWithLastReceived: startWithLastReceivedTrue}, 2},
{"using deliverAll", natsMetadata{DeliverAll: deliverAllTrue}, 2},
{"using startAtTimeDelta", natsMetadata{StartAtTimeDelta: 1 * time.Hour}, 2},
{"using startAtTime and startAtTimeFormat", natsMetadata{StartAtTime: "Feb 3, 2013 at 7:54pm (PST)", StartAtTimeFormat: "Jan 2, 2006 at 3:04pm (MST)"}, 2},
{"using manual ack with ackWaitTime", natsMetadata{AckWaitTime: 30 * time.Second}, 2},
{"using manual ack with maxInFlight", natsMetadata{MaxInFlight: ptr.Of(uint64(42))}, 2},
}
for _, _test := range tests {
@ -365,16 +366,16 @@ func TestSubscriptionOptionsForValidOptions(t *testing.T) {
func TestSubscriptionOptionsForInvalidOptions(t *testing.T) {
type test struct {
name string
m metadata
m natsMetadata
}
tests := []test{
{"startAtSequence is less than 1", metadata{startAtSequence: uint64(0)}},
{"startWithLastReceived is other than true", metadata{startWithLastReceived: "foo"}},
{"deliverAll is other than true", metadata{deliverAll: "foo"}},
{"deliverNew is other than true", metadata{deliverNew: "foo"}},
{"startAtTime is empty", metadata{startAtTime: "", startAtTimeFormat: "Jan 2, 2006 at 3:04pm (MST)"}},
{"startAtTimeFormat is empty", metadata{startAtTime: "Feb 3, 2013 at 7:54pm (PST)", startAtTimeFormat: ""}},
{"startAtSequence is less than 1", natsMetadata{StartAtSequence: ptr.Of(uint64(0))}},
{"startWithLastReceived is other than true", natsMetadata{StartWithLastReceived: "foo"}},
{"deliverAll is other than true", natsMetadata{DeliverAll: "foo"}},
{"deliverNew is other than true", natsMetadata{DeliverNew: "foo"}},
{"startAtTime is empty", natsMetadata{StartAtTime: "", StartAtTimeFormat: "Jan 2, 2006 at 3:04pm (MST)"}},
{"startAtTimeFormat is empty", natsMetadata{StartAtTime: "Feb 3, 2013 at 7:54pm (PST)", StartAtTimeFormat: ""}},
}
for _, _test := range tests {
@ -391,7 +392,7 @@ func TestSubscriptionOptionsForInvalidOptions(t *testing.T) {
func TestSubscriptionOptions(t *testing.T) {
// general
t.Run("manual ACK option is present by default", func(t *testing.T) {
natsStreaming := natsStreamingPubSub{metadata: metadata{}}
natsStreaming := natsStreamingPubSub{metadata: natsMetadata{}}
opts, err := natsStreaming.subscriptionOptions()
assert.Empty(t, err)
assert.NotEmpty(t, opts)
@ -399,7 +400,7 @@ func TestSubscriptionOptions(t *testing.T) {
})
t.Run("only one subscription option will be honored", func(t *testing.T) {
m := metadata{deliverNew: deliverNewTrue, deliverAll: deliverAllTrue, startAtTimeDelta: 1 * time.Hour}
m := natsMetadata{DeliverNew: deliverNewTrue, DeliverAll: deliverAllTrue, StartAtTimeDelta: 1 * time.Hour}
natsStreaming := natsStreamingPubSub{metadata: m}
opts, err := natsStreaming.subscriptionOptions()
assert.Empty(t, err)
@ -410,7 +411,7 @@ func TestSubscriptionOptions(t *testing.T) {
// invalid subscription options
t.Run("startAtTime is invalid", func(t *testing.T) {
m := metadata{startAtTime: "foobar", startAtTimeFormat: "Jan 2, 2006 at 3:04pm (MST)"}
m := natsMetadata{StartAtTime: "foobar", StartAtTimeFormat: "Jan 2, 2006 at 3:04pm (MST)"}
natsStreaming := natsStreamingPubSub{metadata: m}
opts, err := natsStreaming.subscriptionOptions()
assert.NotEmpty(t, err)
@ -418,7 +419,7 @@ func TestSubscriptionOptions(t *testing.T) {
})
t.Run("startAtTimeFormat is invalid", func(t *testing.T) {
m := metadata{startAtTime: "Feb 3, 2013 at 7:54pm (PST)", startAtTimeFormat: "foo"}
m := natsMetadata{StartAtTime: "Feb 3, 2013 at 7:54pm (PST)", StartAtTimeFormat: "foo"}
natsStreaming := natsStreamingPubSub{metadata: m}
opts, err := natsStreaming.subscriptionOptions()

View File

@ -27,6 +27,7 @@ type PubSub interface {
Publish(ctx context.Context, req *PublishRequest) error
Subscribe(ctx context.Context, req SubscribeRequest, handler Handler) error
Close() error
GetComponentMetadata() map[string]string
}
// BulkPublisher is the interface that wraps the BulkPublish method.

View File

@ -16,19 +16,19 @@ package pulsar
import "time"
type pulsarMetadata struct {
Host string `json:"host"`
ConsumerID string `json:"consumerID"`
EnableTLS bool `json:"enableTLS"`
DisableBatching bool `json:"disableBatching"`
BatchingMaxPublishDelay time.Duration `json:"batchingMaxPublishDelay"`
BatchingMaxSize uint `json:"batchingMaxSize"`
BatchingMaxMessages uint `json:"batchingMaxMessages"`
Tenant string `json:"tenant"`
Namespace string `json:"namespace"`
Persistent bool `json:"persistent"`
Token string `json:"token"`
RedeliveryDelay time.Duration `json:"redeliveryDelay"`
topicSchemas map[string]schemaMetadata
Host string `mapstructure:"host"`
ConsumerID string `mapstructure:"consumerID"`
EnableTLS bool `mapstructure:"enableTLS"`
DisableBatching bool `mapstructure:"disableBatching"`
BatchingMaxPublishDelay time.Duration `mapstructure:"batchingMaxPublishDelay"`
BatchingMaxSize uint `mapstructure:"batchingMaxSize"`
BatchingMaxMessages uint `mapstructure:"batchingMaxMessages"`
Tenant string `mapstructure:"tenant"`
Namespace string `mapstructure:"namespace"`
Persistent bool `mapstructure:"persistent"`
Token string `mapstructure:"token"`
RedeliveryDelay time.Duration `mapstructure:"redeliveryDelay"`
internalTopicSchemas map[string]schemaMetadata `mapstructure:"-"`
}
type schemaMetadata struct {

View File

@ -18,7 +18,7 @@ import (
"encoding/json"
"errors"
"fmt"
"strconv"
"reflect"
"strings"
"sync"
"sync/atomic"
@ -29,6 +29,7 @@ import (
"github.com/apache/pulsar-client-go/pulsar"
lru "github.com/hashicorp/golang-lru/v2"
"github.com/dapr/components-contrib/metadata"
"github.com/dapr/components-contrib/pubsub"
"github.com/dapr/kit/logger"
)
@ -106,89 +107,36 @@ func NewPulsar(l logger.Logger) pubsub.PubSub {
}
func parsePulsarMetadata(meta pubsub.Metadata) (*pulsarMetadata, error) {
m := pulsarMetadata{Persistent: true, Tenant: defaultTenant, Namespace: defaultNamespace, topicSchemas: map[string]schemaMetadata{}}
m.ConsumerID = meta.Properties[consumerID]
m := pulsarMetadata{
Persistent: true,
Tenant: defaultTenant,
Namespace: defaultNamespace,
internalTopicSchemas: map[string]schemaMetadata{},
DisableBatching: false,
BatchingMaxPublishDelay: defaultBatchingMaxPublishDelay,
BatchingMaxMessages: defaultMaxMessages,
BatchingMaxSize: defaultMaxBatchSize,
RedeliveryDelay: defaultRedeliveryDelay,
}
if val, ok := meta.Properties[host]; ok && val != "" {
m.Host = val
} else {
if err := metadata.DecodeMetadata(meta.Properties, &m); err != nil {
return nil, err
}
if m.Host == "" {
return nil, errors.New("pulsar error: missing pulsar host")
}
if val, ok := meta.Properties[enableTLS]; ok && val != "" {
tls, err := strconv.ParseBool(val)
if err != nil {
return nil, errors.New("pulsar error: invalid value for enableTLS")
}
m.EnableTLS = tls
}
// DisableBatching is defaultly batching.
m.DisableBatching = false
if val, ok := meta.Properties[disableBatching]; ok {
disableBatching, err := strconv.ParseBool(val)
if err != nil {
return nil, errors.New("pulsar error: invalid value for disableBatching")
}
m.DisableBatching = disableBatching
}
m.BatchingMaxPublishDelay = defaultBatchingMaxPublishDelay
if val, ok := meta.Properties[batchingMaxPublishDelay]; ok {
batchingMaxPublishDelay, err := formatDuration(val)
if err != nil {
return nil, errors.New("pulsar error: invalid value for batchingMaxPublishDelay")
}
m.BatchingMaxPublishDelay = batchingMaxPublishDelay
}
m.BatchingMaxMessages = defaultMaxMessages
if val, ok := meta.Properties[batchingMaxMessages]; ok {
batchingMaxMessages, err := strconv.ParseUint(val, 10, 64)
if err != nil {
return nil, errors.New("pulsar error: invalid value for batchingMaxMessages")
}
m.BatchingMaxMessages = uint(batchingMaxMessages)
}
m.BatchingMaxSize = defaultMaxBatchSize
if val, ok := meta.Properties[batchingMaxSize]; ok {
batchingMaxSize, err := strconv.ParseUint(val, 10, 64)
if err != nil {
return nil, errors.New("pulsar error: invalid value for batchingMaxSize")
}
m.BatchingMaxSize = uint(batchingMaxSize)
}
m.RedeliveryDelay = defaultRedeliveryDelay
if val, ok := meta.Properties[redeliveryDelay]; ok {
redeliveryDelay, err := formatDuration(val)
if err != nil {
return nil, errors.New("pulsar error: invalid value for redeliveryDelay")
}
m.RedeliveryDelay = redeliveryDelay
}
if val, ok := meta.Properties[persistent]; ok && val != "" {
per, err := strconv.ParseBool(val)
if err != nil {
return nil, errors.New("pulsar error: invalid value for persistent")
}
m.Persistent = per
}
if val, ok := meta.Properties[tenant]; ok && val != "" {
m.Tenant = val
}
if val, ok := meta.Properties[namespace]; ok && val != "" {
m.Namespace = val
}
if val, ok := meta.Properties[pulsarToken]; ok && val != "" {
m.Token = val
}
for k, v := range meta.Properties {
if strings.HasSuffix(k, topicJSONSchemaIdentifier) {
topic := k[:len(k)-len(topicJSONSchemaIdentifier)]
m.topicSchemas[topic] = schemaMetadata{
m.internalTopicSchemas[topic] = schemaMetadata{
protocol: jsonProtocol,
value: v,
}
} else if strings.HasSuffix(k, topicAvroSchemaIdentifier) {
topic := k[:len(k)-len(topicJSONSchemaIdentifier)]
m.topicSchemas[topic] = schemaMetadata{
m.internalTopicSchemas[topic] = schemaMetadata{
protocol: avroProtocol,
value: v,
}
@ -253,7 +201,7 @@ func (p *Pulsar) Publish(ctx context.Context, req *pubsub.PublishRequest) error
topic := p.formatTopic(req.Topic)
producer, ok := p.cache.Get(topic)
sm, hasSchema := p.metadata.topicSchemas[req.Topic]
sm, hasSchema := p.metadata.internalTopicSchemas[req.Topic]
if !ok || producer == nil {
p.logger.Debugf("creating producer for topic %s, full topic name in pulsar is %s", req.Topic, topic)
@ -401,7 +349,7 @@ func (p *Pulsar) Subscribe(ctx context.Context, req pubsub.SubscribeRequest, han
NackRedeliveryDelay: p.metadata.RedeliveryDelay,
}
if sm, ok := p.metadata.topicSchemas[req.Topic]; ok {
if sm, ok := p.metadata.internalTopicSchemas[req.Topic]; ok {
options.Schema = getPulsarSchema(sm)
}
consumer, err := p.client.Subscribe(options)
@ -510,13 +458,10 @@ func (p *Pulsar) formatTopic(topic string) string {
return fmt.Sprintf(topicFormat, persist, p.metadata.Tenant, p.metadata.Namespace, topic)
}
func formatDuration(durationString string) (time.Duration, error) {
if val, err := strconv.Atoi(durationString); err == nil {
return time.Duration(val) * time.Millisecond, nil
}
// Convert it by parsing
d, err := time.ParseDuration(durationString)
return d, err
// GetComponentMetadata returns the metadata of the component.
func (p *Pulsar) GetComponentMetadata() map[string]string {
metadataStruct := pulsarMetadata{}
metadataInfo := map[string]string{}
metadata.GetMetadataInfoFromStructType(reflect.TypeOf(metadataStruct), &metadataInfo, metadata.PubSubType)
return metadataInfo
}

View File

@ -44,7 +44,7 @@ func TestParsePulsarMetadata(t *testing.T) {
assert.Equal(t, 5*time.Second, meta.BatchingMaxPublishDelay)
assert.Equal(t, uint(100), meta.BatchingMaxSize)
assert.Equal(t, uint(200), meta.BatchingMaxMessages)
assert.Empty(t, meta.topicSchemas)
assert.Empty(t, meta.internalTopicSchemas)
}
func TestParsePulsarSchemaMetadata(t *testing.T) {
@ -59,9 +59,9 @@ func TestParsePulsarSchemaMetadata(t *testing.T) {
assert.Nil(t, err)
assert.Equal(t, "a", meta.Host)
assert.Len(t, meta.topicSchemas, 2)
assert.Equal(t, "1", meta.topicSchemas["obiwan"].value)
assert.Equal(t, "2", meta.topicSchemas["kenobi.jsonschema"].value)
assert.Len(t, meta.internalTopicSchemas, 2)
assert.Equal(t, "1", meta.internalTopicSchemas["obiwan"].value)
assert.Equal(t, "2", meta.internalTopicSchemas["kenobi.jsonschema"].value)
})
t.Run("test avro", func(t *testing.T) {
@ -75,9 +75,9 @@ func TestParsePulsarSchemaMetadata(t *testing.T) {
assert.Nil(t, err)
assert.Equal(t, "a", meta.Host)
assert.Len(t, meta.topicSchemas, 2)
assert.Equal(t, "1", meta.topicSchemas["obiwan"].value)
assert.Equal(t, "2", meta.topicSchemas["kenobi.avroschema"].value)
assert.Len(t, meta.internalTopicSchemas, 2)
assert.Equal(t, "1", meta.internalTopicSchemas["obiwan"].value)
assert.Equal(t, "2", meta.internalTopicSchemas["kenobi.avroschema"].value)
})
t.Run("test combined avro/json", func(t *testing.T) {
@ -91,11 +91,11 @@ func TestParsePulsarSchemaMetadata(t *testing.T) {
assert.Nil(t, err)
assert.Equal(t, "a", meta.Host)
assert.Len(t, meta.topicSchemas, 2)
assert.Equal(t, "1", meta.topicSchemas["obiwan"].value)
assert.Equal(t, "2", meta.topicSchemas["kenobi"].value)
assert.Equal(t, avroProtocol, meta.topicSchemas["obiwan"].protocol)
assert.Equal(t, jsonProtocol, meta.topicSchemas["kenobi"].protocol)
assert.Len(t, meta.internalTopicSchemas, 2)
assert.Equal(t, "1", meta.internalTopicSchemas["obiwan"].value)
assert.Equal(t, "2", meta.internalTopicSchemas["kenobi"].value)
assert.Equal(t, avroProtocol, meta.internalTopicSchemas["obiwan"].protocol)
assert.Equal(t, jsonProtocol, meta.internalTopicSchemas["kenobi"].protocol)
})
t.Run("test funky edge case", func(t *testing.T) {
@ -108,8 +108,8 @@ func TestParsePulsarSchemaMetadata(t *testing.T) {
assert.Nil(t, err)
assert.Equal(t, "a", meta.Host)
assert.Len(t, meta.topicSchemas, 1)
assert.Equal(t, "1", meta.topicSchemas["obiwan.jsonschema"].value)
assert.Len(t, meta.internalTopicSchemas, 1)
assert.Equal(t, "1", meta.internalTopicSchemas["obiwan.jsonschema"].value)
})
}
@ -158,14 +158,14 @@ func TestMissingHost(t *testing.T) {
assert.Equal(t, "pulsar error: missing pulsar host", err.Error())
}
func TestInvalidTLSInput(t *testing.T) {
func TestInvalidTLSInputDefaultsToFalse(t *testing.T) {
m := pubsub.Metadata{}
m.Properties = map[string]string{"host": "a", "enableTLS": "honk"}
meta, err := parsePulsarMetadata(m)
assert.Error(t, err)
assert.Nil(t, meta)
assert.Equal(t, "pulsar error: invalid value for enableTLS", err.Error())
assert.NoError(t, err)
assert.NotNil(t, meta)
assert.False(t, meta.EnableTLS)
}
func TestValidTenantAndNS(t *testing.T) {

View File

@ -16,42 +16,40 @@ package rabbitmq
import (
"fmt"
"net/url"
"strconv"
"time"
"github.com/dapr/components-contrib/internal/utils"
amqp "github.com/rabbitmq/amqp091-go"
"github.com/dapr/kit/logger"
contribMetadata "github.com/dapr/components-contrib/metadata"
"github.com/dapr/components-contrib/metadata"
"github.com/dapr/components-contrib/pubsub"
)
type metadata struct {
pubsub.TLSProperties
consumerID string
connectionString string
protocol string
hostname string
username string
password string
durable bool
enableDeadLetter bool
deleteWhenUnused bool
autoAck bool
requeueInFailure bool
deliveryMode uint8 // Transient (0 or 1) or Persistent (2)
prefetchCount uint8 // Prefetch deactivated if 0
reconnectWait time.Duration
maxLen int64
maxLenBytes int64
exchangeKind string
publisherConfirm bool
saslExternal bool
concurrency pubsub.ConcurrencyMode
defaultQueueTTL *time.Duration
type rabbitmqMetadata struct {
pubsub.TLSProperties `mapstructure:",squash"`
ConsumerID string `mapstructure:"consumerID"`
ConnectionString string `mapstructure:"connectionString"`
Protocol string `mapstructure:"protocol"`
internalProtocol string `mapstructure:"-"`
Hostname string `mapstructure:"hostname"`
Username string `mapstructure:"username"`
Password string `mapstructure:"password"`
Durable bool `mapstructure:"durable"`
EnableDeadLetter bool `mapstructure:"enableDeadLetter"`
DeleteWhenUnused bool `mapstructure:"deletedWhenUnused"`
AutoAck bool `mapstructure:"autoAck"`
RequeueInFailure bool `mapstructure:"requeueInFailure"`
DeliveryMode uint8 `mapstructure:"deliveryMode"` // Transient (0 or 1) or Persistent (2)
PrefetchCount uint8 `mapstructure:"prefetchCount"` // Prefetch deactivated if 0
ReconnectWait time.Duration `mapstructure:"reconnectWaitSeconds"`
MaxLen int64 `mapstructure:"maxLen"`
MaxLenBytes int64 `mapstructure:"maxLenBytes"`
ExchangeKind string `mapstructure:"exchangeKind"`
PublisherConfirm bool `mapstructure:"publisherConfirm"`
SaslExternal bool `mapstructure:"saslExternal"`
Concurrency pubsub.ConcurrencyMode `mapstructure:"concurrency"`
DefaultQueueTTL *time.Duration `mapstructure:"ttlInSeconds"`
}
const (
@ -87,141 +85,62 @@ const (
)
// createMetadata creates a new instance from the pubsub metadata.
func createMetadata(pubSubMetadata pubsub.Metadata, log logger.Logger) (*metadata, error) {
result := metadata{
protocol: protocolAMQP,
hostname: "localhost",
durable: true,
deleteWhenUnused: true,
autoAck: false,
reconnectWait: time.Duration(defaultReconnectWaitSeconds) * time.Second,
exchangeKind: fanoutExchangeKind,
publisherConfirm: false,
saslExternal: false,
func createMetadata(pubSubMetadata pubsub.Metadata, log logger.Logger) (*rabbitmqMetadata, error) {
result := rabbitmqMetadata{
internalProtocol: protocolAMQP,
Hostname: "localhost",
Durable: true,
DeleteWhenUnused: true,
AutoAck: false,
ReconnectWait: time.Duration(defaultReconnectWaitSeconds) * time.Second,
ExchangeKind: fanoutExchangeKind,
PublisherConfirm: false,
SaslExternal: false,
}
if val, found := pubSubMetadata.Properties[metadataConnectionStringKey]; found && val != "" {
result.connectionString = val
} else if val, found := pubSubMetadata.Properties[metadataHostKey]; found && val != "" {
result.connectionString = val
log.Warn("[DEPRECATION NOTICE] The 'host' argument is deprecated. Use 'connectionString' or individual connection arguments instead: https://docs.dapr.io/reference/components-reference/supported-pubsub/setup-rabbitmq/")
// upgrade metadata
if val, found := pubSubMetadata.Properties[metadataConnectionStringKey]; !found || val == "" {
if host, found := pubSubMetadata.Properties[metadataHostKey]; found && host != "" {
pubSubMetadata.Properties[metadataConnectionStringKey] = host
log.Warn("[DEPRECATION NOTICE] The 'host' argument is deprecated. Use 'connectionString' or individual connection arguments instead: https://docs.dapr.io/reference/components-reference/supported-pubsub/setup-rabbitmq/")
}
}
if result.connectionString != "" {
uri, err := amqp.ParseURI(result.connectionString)
if err := metadata.DecodeMetadata(pubSubMetadata.Properties, &result); err != nil {
return nil, err
}
if result.ConnectionString != "" {
uri, err := amqp.ParseURI(result.ConnectionString)
if err != nil {
return &result, fmt.Errorf("%s invalid connection string: %s, err: %w", errorMessagePrefix, result.connectionString, err)
return &result, fmt.Errorf("%s invalid connection string: %s, err: %w", errorMessagePrefix, result.ConnectionString, err)
}
result.protocol = uri.Scheme
result.internalProtocol = uri.Scheme
}
if val, found := pubSubMetadata.Properties[metadataProtocolKey]; found && val != "" {
if result.connectionString != "" && result.protocol != val {
return &result, fmt.Errorf("%s protocol does not match connection string, protocol: %s, connection string: %s", errorMessagePrefix, val, result.connectionString)
if result.Protocol != "" {
if result.ConnectionString != "" && result.internalProtocol != result.Protocol {
return &result, fmt.Errorf("%s protocol does not match connection string, protocol: %s, connection string: %s", errorMessagePrefix, result.Protocol, result.ConnectionString)
}
result.protocol = val
result.internalProtocol = result.Protocol
}
if val, found := pubSubMetadata.Properties[metadataHostnameKey]; found && val != "" {
result.hostname = val
if result.DeliveryMode > 2 {
return &result, fmt.Errorf("%s invalid RabbitMQ delivery mode, accepted values are between 0 and 2", errorMessagePrefix)
}
if val, found := pubSubMetadata.Properties[metadataUsernameKey]; found && val != "" {
result.username = val
if !exchangeKindValid(result.ExchangeKind) {
return &result, fmt.Errorf("%s invalid RabbitMQ exchange kind %s", errorMessagePrefix, result.ExchangeKind)
}
if val, found := pubSubMetadata.Properties[metadataPasswordKey]; found && val != "" {
result.password = val
}
if val, found := pubSubMetadata.Properties[metadataConsumerIDKey]; found && val != "" {
result.consumerID = val
}
if val, found := pubSubMetadata.Properties[metadataDeliveryModeKey]; found && val != "" {
if intVal, err := strconv.Atoi(val); err == nil {
if intVal < 0 || intVal > 2 {
return &result, fmt.Errorf("%s invalid RabbitMQ delivery mode, accepted values are between 0 and 2", errorMessagePrefix)
}
result.deliveryMode = uint8(intVal)
}
}
if val, found := pubSubMetadata.Properties[metadataDurableKey]; found && val != "" {
if boolVal, err := strconv.ParseBool(val); err == nil {
result.durable = boolVal
}
}
if val, found := pubSubMetadata.Properties[metadataEnableDeadLetterKey]; found && val != "" {
if boolVal, err := strconv.ParseBool(val); err == nil {
result.enableDeadLetter = boolVal
}
}
if val, found := pubSubMetadata.Properties[metadataDeleteWhenUnusedKey]; found && val != "" {
if boolVal, err := strconv.ParseBool(val); err == nil {
result.deleteWhenUnused = boolVal
}
}
if val, found := pubSubMetadata.Properties[metadataAutoAckKey]; found && val != "" {
if boolVal, err := strconv.ParseBool(val); err == nil {
result.autoAck = boolVal
}
}
if val, found := pubSubMetadata.Properties[metadataRequeueInFailureKey]; found && val != "" {
if boolVal, err := strconv.ParseBool(val); err == nil {
result.requeueInFailure = boolVal
}
}
if val, found := pubSubMetadata.Properties[metadataReconnectWaitSecondsKey]; found && val != "" {
if intVal, err := strconv.Atoi(val); err == nil {
result.reconnectWait = time.Duration(intVal) * time.Second
}
}
if val, found := pubSubMetadata.Properties[metadataPrefetchCountKey]; found && val != "" {
if intVal, err := strconv.Atoi(val); err == nil {
result.prefetchCount = uint8(intVal)
}
}
if val, found := pubSubMetadata.Properties[metadataMaxLenKey]; found && val != "" {
if intVal, err := strconv.ParseInt(val, 10, 64); err == nil {
result.maxLen = intVal
}
}
if val, found := pubSubMetadata.Properties[metadataMaxLenBytesKey]; found && val != "" {
if intVal, err := strconv.ParseInt(val, 10, 64); err == nil {
result.maxLenBytes = intVal
}
}
if val, found := pubSubMetadata.Properties[metadataExchangeKindKey]; found && val != "" {
if exchangeKindValid(val) {
result.exchangeKind = val
} else {
return &result, fmt.Errorf("%s invalid RabbitMQ exchange kind %s", errorMessagePrefix, val)
}
}
if val, found := pubSubMetadata.Properties[metadataPublisherConfirmKey]; found && val != "" {
if boolVal, err := strconv.ParseBool(val); err == nil {
result.publisherConfirm = boolVal
}
}
ttl, ok, err := contribMetadata.TryGetTTL(pubSubMetadata.Properties)
ttl, ok, err := metadata.TryGetTTL(pubSubMetadata.Properties)
if err != nil {
return &result, fmt.Errorf("%s parse RabbitMQ ttl metadata with error: %s", errorMessagePrefix, err)
}
if ok {
result.defaultQueueTTL = &ttl
result.DefaultQueueTTL = &ttl
}
result.TLSProperties, err = pubsub.TLS(pubSubMetadata.Properties)
@ -229,32 +148,23 @@ func createMetadata(pubSubMetadata pubsub.Metadata, log logger.Logger) (*metadat
return &result, fmt.Errorf("%s invalid TLS configuration: %w", errorMessagePrefix, err)
}
if val, found := pubSubMetadata.Properties[metadataSaslExternal]; found && val != "" {
boolVal := utils.IsTruthy(val)
if boolVal && (result.TLSProperties.CACert == "" || result.TLSProperties.ClientCert == "" || result.TLSProperties.ClientKey == "") {
return &result, fmt.Errorf("%s can only be set to true, when all these properties are set: %s, %s, %s", metadataSaslExternal, pubsub.CACert, pubsub.ClientCert, pubsub.ClientKey)
}
result.saslExternal = boolVal
if result.SaslExternal && (result.TLSProperties.CACert == "" || result.TLSProperties.ClientCert == "" || result.TLSProperties.ClientKey == "") {
return &result, fmt.Errorf("%s can only be set to true, when all these properties are set: %s, %s, %s", metadataSaslExternal, pubsub.CACert, pubsub.ClientCert, pubsub.ClientKey)
}
c, err := pubsub.Concurrency(pubSubMetadata.Properties)
if err != nil {
return &result, err
}
result.concurrency = c
return &result, nil
result.Concurrency, err = pubsub.Concurrency(pubSubMetadata.Properties)
return &result, err
}
func (m *metadata) formatQueueDeclareArgs(origin amqp.Table) amqp.Table {
func (m *rabbitmqMetadata) formatQueueDeclareArgs(origin amqp.Table) amqp.Table {
if origin == nil {
origin = amqp.Table{}
}
if m.maxLen > 0 {
origin[argMaxLength] = m.maxLen
if m.MaxLen > 0 {
origin[argMaxLength] = m.MaxLen
}
if m.maxLenBytes > 0 {
origin[argMaxLengthBytes] = m.maxLenBytes
if m.MaxLenBytes > 0 {
origin[argMaxLengthBytes] = m.MaxLenBytes
}
return origin
@ -264,20 +174,20 @@ func exchangeKindValid(kind string) bool {
return kind == amqp.ExchangeFanout || kind == amqp.ExchangeTopic || kind == amqp.ExchangeDirect || kind == amqp.ExchangeHeaders
}
func (m *metadata) connectionURI() string {
if m.connectionString != "" {
return m.connectionString
func (m *rabbitmqMetadata) connectionURI() string {
if m.ConnectionString != "" {
return m.ConnectionString
}
u := url.URL{
Scheme: m.protocol,
Host: m.hostname,
Scheme: m.internalProtocol,
Host: m.Hostname,
}
if m.username != "" && m.password != "" {
u.User = url.UserPassword(m.username, m.password)
} else if m.username != "" {
u.User = url.User(m.username)
if m.Username != "" && m.Password != "" {
u.User = url.UserPassword(m.Username, m.Password)
} else if m.Username != "" {
u.User = url.User(m.Username)
}
return u.String()

View File

@ -17,6 +17,7 @@ import (
"crypto/x509"
"encoding/pem"
"fmt"
"strings"
"testing"
amqp "github.com/rabbitmq/amqp091-go"
@ -76,25 +77,25 @@ func TestCreateMetadata(t *testing.T) {
// assert
assert.NoError(t, err)
assert.Equal(t, fakeProperties[metadataConnectionStringKey], m.connectionString)
assert.Equal(t, fakeProperties[metadataProtocolKey], m.protocol)
assert.Equal(t, fakeProperties[metadataHostnameKey], m.hostname)
assert.Equal(t, fakeProperties[metadataUsernameKey], m.username)
assert.Equal(t, fakeProperties[metadataPasswordKey], m.password)
assert.Equal(t, fakeProperties[metadataConsumerIDKey], m.consumerID)
assert.Equal(t, false, m.autoAck)
assert.Equal(t, false, m.requeueInFailure)
assert.Equal(t, true, m.deleteWhenUnused)
assert.Equal(t, false, m.enableDeadLetter)
assert.Equal(t, false, m.publisherConfirm)
assert.Equal(t, uint8(0), m.deliveryMode)
assert.Equal(t, uint8(0), m.prefetchCount)
assert.Equal(t, int64(0), m.maxLen)
assert.Equal(t, int64(0), m.maxLenBytes)
assert.Equal(t, fakeProperties[metadataConnectionStringKey], m.ConnectionString)
assert.Equal(t, fakeProperties[metadataProtocolKey], m.internalProtocol)
assert.Equal(t, fakeProperties[metadataHostnameKey], m.Hostname)
assert.Equal(t, fakeProperties[metadataUsernameKey], m.Username)
assert.Equal(t, fakeProperties[metadataPasswordKey], m.Password)
assert.Equal(t, fakeProperties[metadataConsumerIDKey], m.ConsumerID)
assert.Equal(t, false, m.AutoAck)
assert.Equal(t, false, m.RequeueInFailure)
assert.Equal(t, true, m.DeleteWhenUnused)
assert.Equal(t, false, m.EnableDeadLetter)
assert.Equal(t, false, m.PublisherConfirm)
assert.Equal(t, uint8(0), m.DeliveryMode)
assert.Equal(t, uint8(0), m.PrefetchCount)
assert.Equal(t, int64(0), m.MaxLen)
assert.Equal(t, int64(0), m.MaxLenBytes)
assert.Equal(t, "", m.ClientKey)
assert.Equal(t, "", m.ClientCert)
assert.Equal(t, "", m.CACert)
assert.Equal(t, fanoutExchangeKind, m.exchangeKind)
assert.Equal(t, fanoutExchangeKind, m.ExchangeKind)
})
invalidDeliveryModes := []string{"3", "10", "-1"}
@ -112,10 +113,12 @@ func TestCreateMetadata(t *testing.T) {
m, err := createMetadata(fakeMetaData, log)
// assert
assert.EqualError(t, err, "rabbitmq pub/sub error: invalid RabbitMQ delivery mode, accepted values are between 0 and 2")
assert.Equal(t, fakeProperties[metadataHostnameKey], m.hostname)
assert.Equal(t, fakeProperties[metadataConsumerIDKey], m.consumerID)
assert.Equal(t, uint8(0), m.deliveryMode)
assert.True(t, strings.Contains(err.Error(), "rabbitmq pub/sub error: invalid RabbitMQ delivery mode, accepted values are between 0 and 2") ||
strings.Contains(err.Error(), "'deliveryMode'"))
if deliveryMode != "-1" {
assert.Equal(t, fakeProperties[metadataHostnameKey], m.Hostname)
assert.Equal(t, fakeProperties[metadataConsumerIDKey], m.ConsumerID)
}
})
}
@ -132,9 +135,9 @@ func TestCreateMetadata(t *testing.T) {
// assert
assert.NoError(t, err)
assert.Equal(t, fakeProperties[metadataHostnameKey], m.hostname)
assert.Equal(t, fakeProperties[metadataConsumerIDKey], m.consumerID)
assert.Equal(t, uint8(2), m.deliveryMode)
assert.Equal(t, fakeProperties[metadataHostnameKey], m.Hostname)
assert.Equal(t, fakeProperties[metadataConsumerIDKey], m.ConsumerID)
assert.Equal(t, uint8(2), m.DeliveryMode)
})
t.Run("protocol does not match connection string", func(t *testing.T) {
@ -168,7 +171,7 @@ func TestCreateMetadata(t *testing.T) {
// assert
assert.Nil(t, err)
assert.Equal(t, fakeProperties[metadataProtocolKey], m.protocol)
assert.Equal(t, fakeProperties[metadataProtocolKey], m.internalProtocol)
})
t.Run("invalid concurrency", func(t *testing.T) {
@ -199,9 +202,9 @@ func TestCreateMetadata(t *testing.T) {
// assert
assert.NoError(t, err)
assert.Equal(t, fakeProperties[metadataHostnameKey], m.hostname)
assert.Equal(t, fakeProperties[metadataConsumerIDKey], m.consumerID)
assert.Equal(t, uint8(1), m.prefetchCount)
assert.Equal(t, fakeProperties[metadataHostnameKey], m.Hostname)
assert.Equal(t, fakeProperties[metadataConsumerIDKey], m.ConsumerID)
assert.Equal(t, uint8(1), m.PrefetchCount)
})
t.Run("tls related properties are set", func(t *testing.T) {
@ -249,10 +252,10 @@ func TestCreateMetadata(t *testing.T) {
// assert
assert.NoError(t, err)
assert.Equal(t, fakeProperties[metadataHostnameKey], m.hostname)
assert.Equal(t, fakeProperties[metadataConsumerIDKey], m.consumerID)
assert.Equal(t, int64(1), m.maxLen)
assert.Equal(t, int64(2000000), m.maxLenBytes)
assert.Equal(t, fakeProperties[metadataHostnameKey], m.Hostname)
assert.Equal(t, fakeProperties[metadataConsumerIDKey], m.ConsumerID)
assert.Equal(t, int64(1), m.MaxLen)
assert.Equal(t, int64(2000000), m.MaxLenBytes)
})
for _, tt := range booleanFlagTests {
@ -269,9 +272,9 @@ func TestCreateMetadata(t *testing.T) {
// assert
assert.NoError(t, err)
assert.Equal(t, fakeProperties[metadataHostnameKey], m.hostname)
assert.Equal(t, fakeProperties[metadataConsumerIDKey], m.consumerID)
assert.Equal(t, tt.expected, m.autoAck)
assert.Equal(t, fakeProperties[metadataHostnameKey], m.Hostname)
assert.Equal(t, fakeProperties[metadataConsumerIDKey], m.ConsumerID)
assert.Equal(t, tt.expected, m.AutoAck)
})
}
@ -289,9 +292,9 @@ func TestCreateMetadata(t *testing.T) {
// assert
assert.NoError(t, err)
assert.Equal(t, fakeProperties[metadataHostnameKey], m.hostname)
assert.Equal(t, fakeProperties[metadataConsumerIDKey], m.consumerID)
assert.Equal(t, tt.expected, m.requeueInFailure)
assert.Equal(t, fakeProperties[metadataHostnameKey], m.Hostname)
assert.Equal(t, fakeProperties[metadataConsumerIDKey], m.ConsumerID)
assert.Equal(t, tt.expected, m.RequeueInFailure)
})
}
@ -309,9 +312,9 @@ func TestCreateMetadata(t *testing.T) {
// assert
assert.NoError(t, err)
assert.Equal(t, fakeProperties[metadataHostnameKey], m.hostname)
assert.Equal(t, fakeProperties[metadataConsumerIDKey], m.consumerID)
assert.Equal(t, tt.expected, m.deleteWhenUnused)
assert.Equal(t, fakeProperties[metadataHostnameKey], m.Hostname)
assert.Equal(t, fakeProperties[metadataConsumerIDKey], m.ConsumerID)
assert.Equal(t, tt.expected, m.DeleteWhenUnused)
})
}
@ -329,9 +332,9 @@ func TestCreateMetadata(t *testing.T) {
// assert
assert.NoError(t, err)
assert.Equal(t, fakeProperties[metadataHostnameKey], m.hostname)
assert.Equal(t, fakeProperties[metadataConsumerIDKey], m.consumerID)
assert.Equal(t, tt.expected, m.durable)
assert.Equal(t, fakeProperties[metadataHostnameKey], m.Hostname)
assert.Equal(t, fakeProperties[metadataConsumerIDKey], m.ConsumerID)
assert.Equal(t, tt.expected, m.Durable)
})
}
@ -349,9 +352,9 @@ func TestCreateMetadata(t *testing.T) {
// assert
assert.NoError(t, err)
assert.Equal(t, fakeProperties[metadataHostnameKey], m.hostname)
assert.Equal(t, fakeProperties[metadataConsumerIDKey], m.consumerID)
assert.Equal(t, tt.expected, m.publisherConfirm)
assert.Equal(t, fakeProperties[metadataHostnameKey], m.Hostname)
assert.Equal(t, fakeProperties[metadataConsumerIDKey], m.ConsumerID)
assert.Equal(t, tt.expected, m.PublisherConfirm)
})
}
@ -369,9 +372,9 @@ func TestCreateMetadata(t *testing.T) {
// assert
assert.NoError(t, err)
assert.Equal(t, fakeProperties[metadataHostnameKey], m.hostname)
assert.Equal(t, fakeProperties[metadataConsumerIDKey], m.consumerID)
assert.Equal(t, tt.expected, m.enableDeadLetter)
assert.Equal(t, fakeProperties[metadataHostnameKey], m.Hostname)
assert.Equal(t, fakeProperties[metadataConsumerIDKey], m.ConsumerID)
assert.Equal(t, tt.expected, m.EnableDeadLetter)
})
}
validExchangeKind := []string{amqp.ExchangeDirect, amqp.ExchangeTopic, amqp.ExchangeFanout, amqp.ExchangeHeaders}
@ -390,9 +393,9 @@ func TestCreateMetadata(t *testing.T) {
// assert
assert.NoError(t, err)
assert.Equal(t, fakeProperties[metadataHostnameKey], m.hostname)
assert.Equal(t, fakeProperties[metadataConsumerIDKey], m.consumerID)
assert.Equal(t, exchangeKind, m.exchangeKind)
assert.Equal(t, fakeProperties[metadataHostnameKey], m.Hostname)
assert.Equal(t, fakeProperties[metadataConsumerIDKey], m.ConsumerID)
assert.Equal(t, exchangeKind, m.ExchangeKind)
})
}

View File

@ -19,6 +19,7 @@ import (
"errors"
"fmt"
"math"
"reflect"
"strconv"
"strings"
"sync"
@ -27,7 +28,7 @@ import (
amqp "github.com/rabbitmq/amqp091-go"
contribMetadata "github.com/dapr/components-contrib/metadata"
"github.com/dapr/components-contrib/metadata"
"github.com/dapr/components-contrib/pubsub"
"github.com/dapr/kit/logger"
)
@ -59,7 +60,7 @@ type rabbitMQ struct {
channel rabbitMQChannelBroker
channelMutex sync.RWMutex
connectionCount int
metadata *metadata
metadata *rabbitmqMetadata
declaredExchanges map[string]bool
connectionDial func(protocol, uri string, tlsCfg *tls.Config, externalSasl bool) (rabbitMQConnectionBroker, rabbitMQChannelBroker, error)
@ -175,14 +176,14 @@ func (r *rabbitMQ) reconnect(connectionCount int) error {
return err
}
r.connection, r.channel, err = r.connectionDial(r.metadata.protocol, r.metadata.connectionURI(), tlsCfg, r.metadata.saslExternal)
r.connection, r.channel, err = r.connectionDial(r.metadata.internalProtocol, r.metadata.connectionURI(), tlsCfg, r.metadata.SaslExternal)
if err != nil {
r.reset()
return err
}
if r.metadata.publisherConfirm {
if r.metadata.PublisherConfirm {
err = r.channel.Confirm(false)
if err != nil {
r.reset()
@ -206,7 +207,7 @@ func (r *rabbitMQ) publishSync(ctx context.Context, req *pubsub.PublishRequest)
return r.channel, r.connectionCount, errors.New(errorChannelNotInitialized)
}
if err := r.ensureExchangeDeclared(r.channel, req.Topic, r.metadata.exchangeKind); err != nil {
if err := r.ensureExchangeDeclared(r.channel, req.Topic, r.metadata.ExchangeKind); err != nil {
r.logger.Errorf("%s publishing to %s failed in ensureExchangeDeclared: %v", logMessagePrefix, req.Topic, err)
return r.channel, r.connectionCount, err
@ -216,7 +217,7 @@ func (r *rabbitMQ) publishSync(ctx context.Context, req *pubsub.PublishRequest)
routingKey = val
}
ttl, ok, err := contribMetadata.TryGetTTL(req.Metadata)
ttl, ok, err := metadata.TryGetTTL(req.Metadata)
if err != nil {
r.logger.Warnf("%s publishing to %s failed to parse TryGetTTL: %v, it is ignored.", logMessagePrefix, req.Topic, err)
}
@ -224,18 +225,18 @@ func (r *rabbitMQ) publishSync(ctx context.Context, req *pubsub.PublishRequest)
if ok {
// RabbitMQ expects the duration in ms
expiration = strconv.FormatInt(ttl.Milliseconds(), 10)
} else if r.metadata.defaultQueueTTL != nil {
expiration = strconv.FormatInt(r.metadata.defaultQueueTTL.Milliseconds(), 10)
} else if r.metadata.DefaultQueueTTL != nil {
expiration = strconv.FormatInt(r.metadata.DefaultQueueTTL.Milliseconds(), 10)
}
p := amqp.Publishing{
ContentType: "text/plain",
Body: req.Data,
DeliveryMode: r.metadata.deliveryMode,
DeliveryMode: r.metadata.DeliveryMode,
Expiration: expiration,
}
priority, ok, err := contribMetadata.TryGetPriority(req.Metadata)
priority, ok, err := metadata.TryGetPriority(req.Metadata)
if err != nil {
r.logger.Warnf("%s publishing to %s failed to parse priority: %v, it is ignored.", logMessagePrefix, req.Topic, err)
}
@ -283,9 +284,9 @@ func (r *rabbitMQ) Publish(ctx context.Context, req *pubsub.PublishRequest) erro
return err
}
if mustReconnect(channel, err) {
r.logger.Warnf("%s publisher is reconnecting in %s ...", logMessagePrefix, r.metadata.reconnectWait.String())
r.logger.Warnf("%s publisher is reconnecting in %s ...", logMessagePrefix, r.metadata.ReconnectWait.String())
select {
case <-time.After(r.metadata.reconnectWait):
case <-time.After(r.metadata.ReconnectWait):
case <-ctx.Done():
return nil
}
@ -307,11 +308,11 @@ func (r *rabbitMQ) Subscribe(ctx context.Context, req pubsub.SubscribeRequest, h
return errors.New("component is closed")
}
if r.metadata.consumerID == "" {
if r.metadata.ConsumerID == "" {
return errors.New("consumerID is required for subscriptions")
}
queueName := fmt.Sprintf("%s-%s", r.metadata.consumerID, req.Topic)
queueName := fmt.Sprintf("%s-%s", r.metadata.ConsumerID, req.Topic)
r.logger.Infof("%s subscribe to topic/queue '%s/%s'", logMessagePrefix, req.Topic, queueName)
// Do not set a timeout on the context, as we're just waiting for the first ack; we're using a semaphore instead
@ -344,7 +345,7 @@ func (r *rabbitMQ) Subscribe(ctx context.Context, req pubsub.SubscribeRequest, h
// this function call should be wrapped by channelMutex.
func (r *rabbitMQ) prepareSubscription(channel rabbitMQChannelBroker, req pubsub.SubscribeRequest, queueName string) (*amqp.Queue, error) {
err := r.ensureExchangeDeclared(channel, req.Topic, r.metadata.exchangeKind)
err := r.ensureExchangeDeclared(channel, req.Topic, r.metadata.ExchangeKind)
if err != nil {
r.logger.Errorf("%s prepareSubscription for topic/queue '%s/%s' failed in ensureExchangeDeclared: %v", logMessagePrefix, req.Topic, queueName, err)
@ -353,7 +354,7 @@ func (r *rabbitMQ) prepareSubscription(channel rabbitMQChannelBroker, req pubsub
r.logger.Infof("%s declaring queue '%s'", logMessagePrefix, queueName)
var args amqp.Table
if r.metadata.enableDeadLetter {
if r.metadata.EnableDeadLetter {
// declare dead letter exchange
dlxName := fmt.Sprintf(defaultDeadLetterExchangeFormat, queueName)
dlqName := fmt.Sprintf(defaultDeadLetterQueueFormat, queueName)
@ -367,7 +368,7 @@ func (r *rabbitMQ) prepareSubscription(channel rabbitMQChannelBroker, req pubsub
dlqArgs := r.metadata.formatQueueDeclareArgs(nil)
// dead letter queue use lazy mode, keeping as many messages as possible on disk to reduce RAM usage
dlqArgs[argQueueMode] = queueModeLazy
q, err = channel.QueueDeclare(dlqName, true, r.metadata.deleteWhenUnused, false, false, dlqArgs)
q, err = channel.QueueDeclare(dlqName, true, r.metadata.DeleteWhenUnused, false, false, dlqArgs)
if err != nil {
r.logger.Errorf("%s prepareSubscription for topic/queue '%s/%s' failed in channel.QueueDeclare: %v", logMessagePrefix, req.Topic, dlqName, err)
@ -400,16 +401,16 @@ func (r *rabbitMQ) prepareSubscription(channel rabbitMQChannelBroker, req pubsub
args[argMaxPriority] = mp
}
q, err := channel.QueueDeclare(queueName, r.metadata.durable, r.metadata.deleteWhenUnused, false, false, args)
q, err := channel.QueueDeclare(queueName, r.metadata.Durable, r.metadata.DeleteWhenUnused, false, false, args)
if err != nil {
r.logger.Errorf("%s prepareSubscription for topic/queue '%s/%s' failed in channel.QueueDeclare: %v", logMessagePrefix, req.Topic, queueName, err)
return nil, err
}
if r.metadata.prefetchCount > 0 {
r.logger.Infof("%s setting prefetch count to %s", logMessagePrefix, strconv.Itoa(int(r.metadata.prefetchCount)))
err = channel.Qos(int(r.metadata.prefetchCount), 0, false)
if r.metadata.PrefetchCount > 0 {
r.logger.Infof("%s setting prefetch count to %s", logMessagePrefix, strconv.Itoa(int(r.metadata.PrefetchCount)))
err = channel.Qos(int(r.metadata.PrefetchCount), 0, false)
if err != nil {
r.logger.Errorf("%s prepareSubscription for topic/queue '%s/%s' failed in channel.Qos: %v", logMessagePrefix, req.Topic, queueName, err)
@ -469,7 +470,7 @@ func (r *rabbitMQ) subscribeForever(ctx context.Context, req pubsub.SubscribeReq
msgs, err = channel.Consume(
q.Name,
queueName, // consumerId
r.metadata.autoAck, // autoAck
r.metadata.AutoAck, // autoAck
false, // exclusive
false, // noLocal
false, // noWait
@ -510,9 +511,9 @@ func (r *rabbitMQ) subscribeForever(ctx context.Context, req pubsub.SubscribeReq
}
if mustReconnect(channel, err) {
r.logger.Warnf("%s subscriber is reconnecting in %s ...", logMessagePrefix, r.metadata.reconnectWait.String())
r.logger.Warnf("%s subscriber is reconnecting in %s ...", logMessagePrefix, r.metadata.ReconnectWait.String())
select {
case <-time.After(r.metadata.reconnectWait):
case <-time.After(r.metadata.ReconnectWait):
case <-ctx.Done():
r.logger.Infof("%s subscription for %s has context canceled", logMessagePrefix, queueName)
return
@ -535,7 +536,7 @@ func (r *rabbitMQ) listenMessages(ctx context.Context, channel rabbitMQChannelBr
return nil
}
switch r.metadata.concurrency {
switch r.metadata.Concurrency {
case pubsub.Single:
err = r.handleMessage(ctx, d, topic, handler)
if err != nil && mustReconnect(channel, err) {
@ -565,14 +566,14 @@ func (r *rabbitMQ) handleMessage(ctx context.Context, d amqp.Delivery, topic str
if err != nil {
r.logger.Errorf("%s handling message from topic '%s', %s", errorMessagePrefix, topic, err)
if !r.metadata.autoAck {
if !r.metadata.AutoAck {
// if message is not auto acked we need to ack/nack
r.logger.Debugf("%s nacking message '%s' from topic '%s', requeue=%t", logMessagePrefix, d.MessageId, topic, r.metadata.requeueInFailure)
if err = d.Nack(false, r.metadata.requeueInFailure); err != nil {
r.logger.Debugf("%s nacking message '%s' from topic '%s', requeue=%t", logMessagePrefix, d.MessageId, topic, r.metadata.RequeueInFailure)
if err = d.Nack(false, r.metadata.RequeueInFailure); err != nil {
r.logger.Errorf("%s error nacking message '%s' from topic '%s', %s", logMessagePrefix, d.MessageId, topic, err)
}
}
} else if !r.metadata.autoAck {
} else if !r.metadata.AutoAck {
// if message is not auto acked we need to ack/nack
r.logger.Debugf("%s acking message '%s' from topic '%s'", logMessagePrefix, d.MessageId, topic)
if err = d.Ack(false); err != nil {
@ -670,3 +671,11 @@ func mustReconnect(channel rabbitMQChannelBroker, err error) bool {
return strings.Contains(err.Error(), errorChannelConnection)
}
// GetComponentMetadata returns the metadata of the component.
func (r *rabbitMQ) GetComponentMetadata() map[string]string {
	info := map[string]string{}
	metadata.GetMetadataInfoFromStructType(reflect.TypeOf(rabbitmqMetadata{}), &info, metadata.PubSubType)
	return info
}

View File

@ -118,7 +118,7 @@ func TestConcurrencyMode(t *testing.T) {
}}
err := pubsubRabbitMQ.Init(context.Background(), metadata)
assert.Nil(t, err)
assert.Equal(t, pubsub.Parallel, pubsubRabbitMQ.(*rabbitMQ).metadata.concurrency)
assert.Equal(t, pubsub.Parallel, pubsubRabbitMQ.(*rabbitMQ).metadata.Concurrency)
})
t.Run("single", func(t *testing.T) {
@ -133,7 +133,7 @@ func TestConcurrencyMode(t *testing.T) {
}}
err := pubsubRabbitMQ.Init(context.Background(), metadata)
assert.Nil(t, err)
assert.Equal(t, pubsub.Single, pubsubRabbitMQ.(*rabbitMQ).metadata.concurrency)
assert.Equal(t, pubsub.Single, pubsubRabbitMQ.(*rabbitMQ).metadata.Concurrency)
})
t.Run("default", func(t *testing.T) {
@ -147,7 +147,7 @@ func TestConcurrencyMode(t *testing.T) {
}}
err := pubsubRabbitMQ.Init(context.Background(), metadata)
assert.Nil(t, err)
assert.Equal(t, pubsub.Parallel, pubsubRabbitMQ.(*rabbitMQ).metadata.concurrency)
assert.Equal(t, pubsub.Parallel, pubsubRabbitMQ.(*rabbitMQ).metadata.Concurrency)
})
}

View File

@ -17,12 +17,14 @@ import (
"context"
"errors"
"fmt"
"reflect"
"strconv"
"sync"
"sync/atomic"
"time"
rediscomponent "github.com/dapr/components-contrib/internal/component/redis"
contribMetadata "github.com/dapr/components-contrib/metadata"
"github.com/dapr/components-contrib/pubsub"
"github.com/dapr/kit/logger"
)
@ -466,3 +468,10 @@ func (r *redisStreams) Ping(ctx context.Context) error {
return nil
}
// GetComponentMetadata returns the metadata of the component.
func (r *redisStreams) GetComponentMetadata() map[string]string {
	info := map[string]string{}
	contribMetadata.GetMetadataInfoFromStructType(reflect.TypeOf(metadata{}), &info, contribMetadata.PubSubType)
	return info
}

View File

@ -18,6 +18,7 @@ import (
"encoding/json"
"errors"
"fmt"
"reflect"
"strconv"
"strings"
"sync"
@ -31,6 +32,7 @@ import (
"github.com/apache/rocketmq-client-go/v2/rlog"
"github.com/dapr/components-contrib/internal/utils"
"github.com/dapr/components-contrib/metadata"
"github.com/dapr/components-contrib/pubsub"
"github.com/dapr/kit/logger"
)
@ -526,3 +528,11 @@ func (r *rocketMQ) Close() error {
return nil
}
// GetComponentMetadata returns the metadata of the component.
func (r *rocketMQ) GetComponentMetadata() map[string]string {
	info := map[string]string{}
	metadata.GetMetadataInfoFromStructType(reflect.TypeOf(rocketMQMetaData{}), &info, metadata.PubSubType)
	return info
}

View File

@ -20,6 +20,7 @@ import (
"errors"
"fmt"
"net/url"
"reflect"
"strconv"
"strings"
"sync"
@ -28,6 +29,7 @@ import (
amqp "github.com/Azure/go-amqp"
contribMetadata "github.com/dapr/components-contrib/metadata"
"github.com/dapr/components-contrib/pubsub"
"github.com/dapr/kit/logger"
)
@ -235,15 +237,15 @@ func (a *amqpPubSub) subscribeForever(ctx context.Context, receiver *amqp.Receiv
// Connect to the AMQP broker
func (a *amqpPubSub) connect(ctx context.Context) (*amqp.Session, error) {
uri, err := url.Parse(a.metadata.url)
uri, err := url.Parse(a.metadata.URL)
if err != nil {
return nil, err
}
clientOpts := a.createClientOptions(uri)
a.logger.Infof("Attempting to connect to %s", a.metadata.url)
client, err := amqp.Dial(a.metadata.url, &clientOpts)
a.logger.Infof("Attempting to connect to %s", a.metadata.URL)
client, err := amqp.Dial(a.metadata.URL, &clientOpts)
if err != nil {
a.logger.Fatal("Dialing AMQP server:", err)
}
@ -260,8 +262,8 @@ func (a *amqpPubSub) connect(ctx context.Context) (*amqp.Session, error) {
func (a *amqpPubSub) newTLSConfig() *tls.Config {
tlsConfig := new(tls.Config)
if a.metadata.clientCert != "" && a.metadata.clientKey != "" {
cert, err := tls.X509KeyPair([]byte(a.metadata.clientCert), []byte(a.metadata.clientKey))
if a.metadata.ClientCert != "" && a.metadata.ClientKey != "" {
cert, err := tls.X509KeyPair([]byte(a.metadata.ClientCert), []byte(a.metadata.ClientKey))
if err != nil {
a.logger.Warnf("unable to load client certificate and key pair. Err: %v", err)
@ -270,9 +272,9 @@ func (a *amqpPubSub) newTLSConfig() *tls.Config {
tlsConfig.Certificates = []tls.Certificate{cert}
}
if a.metadata.caCert != "" {
if a.metadata.CaCert != "" {
tlsConfig.RootCAs = x509.NewCertPool()
if ok := tlsConfig.RootCAs.AppendCertsFromPEM([]byte(a.metadata.caCert)); !ok {
if ok := tlsConfig.RootCAs.AppendCertsFromPEM([]byte(a.metadata.CaCert)); !ok {
a.logger.Warnf("unable to load ca certificate.")
}
}
@ -287,13 +289,13 @@ func (a *amqpPubSub) createClientOptions(uri *url.URL) amqp.ConnOptions {
switch scheme {
case "amqp":
if a.metadata.anonymous {
if a.metadata.Anonymous {
opts.SASLType = amqp.SASLTypeAnonymous()
} else {
opts.SASLType = amqp.SASLTypePlain(a.metadata.username, a.metadata.password)
opts.SASLType = amqp.SASLTypePlain(a.metadata.Username, a.metadata.Password)
}
case "amqps":
opts.SASLType = amqp.SASLTypePlain(a.metadata.username, a.metadata.password)
opts.SASLType = amqp.SASLTypePlain(a.metadata.Username, a.metadata.Password)
opts.TLSConfig = a.newTLSConfig()
}
@ -323,3 +325,11 @@ func (a *amqpPubSub) Close() error {
func (a *amqpPubSub) Features() []pubsub.Feature {
return []pubsub.Feature{pubsub.FeatureSubscribeWildcards, pubsub.FeatureMessageTTL}
}
// GetComponentMetadata returns the metadata of the component.
func (a *amqpPubSub) GetComponentMetadata() map[string]string {
	info := map[string]string{}
	contribMetadata.GetMetadataInfoFromStructType(reflect.TypeOf(metadata{}), &info, contribMetadata.PubSubType)
	return info
}

View File

@ -48,7 +48,7 @@ func TestParseMetadata(t *testing.T) {
// assert
assert.NoError(t, err)
assert.Equal(t, fakeProperties[amqpURL], m.url)
assert.Equal(t, fakeProperties[amqpURL], m.URL)
})
t.Run("url is not given", func(t *testing.T) {
@ -63,7 +63,7 @@ func TestParseMetadata(t *testing.T) {
// assert
assert.EqualError(t, err, errors.New(errorMsgPrefix+" missing url").Error())
assert.Equal(t, fakeProperties[amqpURL], m.url)
assert.Equal(t, fakeProperties[amqpURL], m.URL)
})
t.Run("invalid ca certificate", func(t *testing.T) {
@ -84,7 +84,7 @@ func TestParseMetadata(t *testing.T) {
// assert
assert.NoError(t, err)
block, _ := pem.Decode([]byte(m.tlsCfg.caCert))
block, _ := pem.Decode([]byte(m.tlsCfg.CaCert))
cert, err := x509.ParseCertificate(block.Bytes)
if err != nil {
t.Errorf("failed to parse ca certificate from metadata. %v", err)
@ -110,7 +110,7 @@ func TestParseMetadata(t *testing.T) {
// assert
assert.NoError(t, err)
block, _ := pem.Decode([]byte(m.tlsCfg.clientCert))
block, _ := pem.Decode([]byte(m.tlsCfg.ClientCert))
cert, err := x509.ParseCertificate(block.Bytes)
if err != nil {
t.Errorf("failed to parse client certificate from metadata. %v", err)
@ -136,6 +136,6 @@ func TestParseMetadata(t *testing.T) {
// assert
assert.NoError(t, err)
assert.NotNil(t, m.tlsCfg.clientKey, "failed to parse valid client certificate key")
assert.NotNil(t, m.tlsCfg.ClientKey, "failed to parse valid client certificate key")
})
}

View File

@ -16,9 +16,9 @@ package amqp
import (
"encoding/pem"
"fmt"
"strconv"
"time"
contribMetadata "github.com/dapr/components-contrib/metadata"
"github.com/dapr/components-contrib/pubsub"
"github.com/dapr/kit/logger"
)
@ -29,17 +29,17 @@ const (
)
// metadata is the AMQP pubsub component configuration, decoded from the
// component's metadata properties via mapstructure. The embedded tlsCfg is
// squashed so its fields are decoded from the same flat property map.
// Fix: the previous text interleaved the old unexported fields (url,
// username, ...) with the new exported ones, including a duplicate embedded
// tlsCfg field — that does not compile; only the exported set is kept.
type metadata struct {
	tlsCfg `mapstructure:",squash"`

	URL       string // broker address; required
	Username  string // required unless Anonymous is true
	Password  string // required unless Anonymous is true
	Anonymous bool   // use SASL ANONYMOUS instead of PLAIN
}

// tlsCfg holds the PEM-encoded TLS material used for amqps connections.
type tlsCfg struct {
	CaCert     string
	ClientCert string
	ClientKey  string
}
const (
@ -62,55 +62,43 @@ func isValidPEM(val string) bool {
}
func parseAMQPMetaData(md pubsub.Metadata, log logger.Logger) (*metadata, error) {
m := metadata{anonymous: false}
m := metadata{Anonymous: false}
err := contribMetadata.DecodeMetadata(md.Properties, &m)
if err != nil {
return &m, fmt.Errorf("%s %s", errorMsgPrefix, err)
}
// required configuration settings
if val, ok := md.Properties[amqpURL]; ok && val != "" {
m.url = val
} else {
if m.URL == "" {
return &m, fmt.Errorf("%s missing url", errorMsgPrefix)
}
// optional configuration settings
if val, ok := md.Properties[anonymous]; ok && val != "" {
var err error
m.anonymous, err = strconv.ParseBool(val)
if err != nil {
return &m, fmt.Errorf("%s invalid anonymous %s, %s", errorMsgPrefix, val, err)
}
}
if !m.anonymous {
if val, ok := md.Properties[username]; ok && val != "" {
m.username = val
} else {
if !m.Anonymous {
if m.Username == "" {
return &m, fmt.Errorf("%s missing username", errorMsgPrefix)
}
if val, ok := md.Properties[password]; ok && val != "" {
m.password = val
} else {
if m.Password == "" {
return &m, fmt.Errorf("%s missing username", errorMsgPrefix)
}
}
if val, ok := md.Properties[amqpCACert]; ok && val != "" {
if !isValidPEM(val) {
if m.CaCert != "" {
if !isValidPEM(m.CaCert) {
return &m, fmt.Errorf("%s invalid caCert", errorMsgPrefix)
}
m.tlsCfg.caCert = val
}
if val, ok := md.Properties[amqpClientCert]; ok && val != "" {
if !isValidPEM(val) {
if m.ClientCert != "" {
if !isValidPEM(m.ClientCert) {
return &m, fmt.Errorf("%s invalid clientCert", errorMsgPrefix)
}
m.tlsCfg.clientCert = val
}
if val, ok := md.Properties[amqpClientKey]; ok && val != "" {
if !isValidPEM(val) {
if m.ClientKey != "" {
if !isValidPEM(m.ClientKey) {
return &m, fmt.Errorf("%s invalid clientKey", errorMsgPrefix)
}
m.tlsCfg.clientKey = val
}
return &m, nil