pubsub.mqtt: support wildcard and shared subscriptions (#1882)

Fixes #1881

Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com>
This commit is contained in:
Alessandro (Ale) Segala 2022-07-18 17:26:07 -07:00 committed by GitHub
parent 3195217f12
commit 66eee69188
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 431 additions and 30 deletions

View File

@ -18,8 +18,11 @@ import (
"crypto/tls"
"crypto/x509"
"encoding/pem"
"errors"
"fmt"
"net/url"
"regexp"
"strings"
"sync"
"time"
@ -41,18 +44,24 @@ type mqttPubSub struct {
consumer mqtt.Client
metadata *metadata
logger logger.Logger
topics map[string]pubsub.Handler
topics map[string]mqttPubSubSubscription
retriableErrLimit ratelimit.Limiter
subscribingLock sync.Mutex
subscribingLock sync.RWMutex
ctx context.Context
cancel context.CancelFunc
}
type mqttPubSubSubscription struct {
handler pubsub.Handler
alias string
matcher func(topic string) bool
}
// NewMQTTPubSub returns a new mqttPubSub instance.
func NewMQTTPubSub(logger logger.Logger) pubsub.PubSub {
return &mqttPubSub{
logger: logger,
subscribingLock: sync.Mutex{},
subscribingLock: sync.RWMutex{},
}
}
@ -89,7 +98,7 @@ func (m *mqttPubSub) Init(metadata pubsub.Metadata) error {
}
m.producer = p
m.topics = make(map[string]pubsub.Handler)
m.topics = make(map[string]mqttPubSubSubscription)
m.logger.Debug("mqtt message bus initialization complete")
@ -98,6 +107,10 @@ func (m *mqttPubSub) Init(metadata pubsub.Metadata) error {
// Publish the topic to mqtt pub sub.
func (m *mqttPubSub) Publish(req *pubsub.PublishRequest) error {
if req.Topic == "" {
return errors.New("topic name is empty")
}
// Note this can contain PII
// m.logger.Debugf("mqtt publishing topic %s with data: %v", req.Topic, req.Data)
m.logger.Debugf("mqtt publishing topic %s", req.Topic)
@ -132,6 +145,10 @@ func (m *mqttPubSub) Subscribe(ctx context.Context, req pubsub.SubscribeRequest,
return ctxErr
}
if req.Topic == "" {
return errors.New("topic name is empty")
}
m.subscribingLock.Lock()
defer m.subscribingLock.Unlock()
@ -139,7 +156,7 @@ func (m *mqttPubSub) Subscribe(ctx context.Context, req pubsub.SubscribeRequest,
m.resetSubscription()
// Add the topic then start the subscription
m.topics[req.Topic] = handler
m.addTopic(req.Topic, handler)
// Use the global context here to maintain the connection
m.startSubscription(m.ctx)
@ -236,13 +253,13 @@ func (m *mqttPubSub) onMessage(ctx context.Context) func(client mqtt.Client, mqt
// Note that if the connection drops before the message is explicitly ACK'd below, then it's automatically re-sent (assuming QoS is 1 or greater, which is the default). So we do not risk losing messages.
// Problem with this approach is that if the service crashes between the time the message is re-enqueued and when the ACK is sent, the message may be delivered twice
if !ack {
m.logger.Debugf("Re-publishing message %s/%d", mqttMsg.Topic(), mqttMsg.MessageID())
m.logger.Debugf("Re-publishing message %s#%d", mqttMsg.Topic(), mqttMsg.MessageID())
publishErr := m.Publish(&pubsub.PublishRequest{
Topic: mqttMsg.Topic(),
Data: mqttMsg.Payload(),
})
if publishErr != nil {
m.logger.Errorf("Failed to re-publish message %s/%d. Error: %v", mqttMsg.Topic(), mqttMsg.MessageID(), publishErr)
m.logger.Errorf("Failed to re-publish message %s#%d. Error: %v", mqttMsg.Topic(), mqttMsg.MessageID(), publishErr)
// Return so Ack() isn't invoked
return
}
@ -263,24 +280,48 @@ func (m *mqttPubSub) onMessage(ctx context.Context) func(client mqtt.Client, mqt
Data: mqttMsg.Payload(),
}
topicHandler, ok := m.topics[msg.Topic]
if !ok || topicHandler == nil {
topicHandler := m.handlerForTopic(msg.Topic)
if topicHandler == nil {
m.logger.Errorf("no handler defined for topic %s", msg.Topic)
return
}
m.logger.Debugf("Processing MQTT message %s/%d (retained=%v)", mqttMsg.Topic(), mqttMsg.MessageID(), mqttMsg.Retained())
m.logger.Debugf("Processing MQTT message %s#%d (retained=%v)", mqttMsg.Topic(), mqttMsg.MessageID(), mqttMsg.Retained())
err := topicHandler(ctx, &msg)
if err != nil {
m.logger.Errorf("Failed processing MQTT message %s/%d: %v", mqttMsg.Topic(), mqttMsg.MessageID(), err)
m.logger.Errorf("Failed processing MQTT message %s#%d: %v", mqttMsg.Topic(), mqttMsg.MessageID(), err)
return
}
m.logger.Debugf("Done processing MQTT message %s/%d; sending ACK", mqttMsg.Topic(), mqttMsg.MessageID())
m.logger.Debugf("Done processing MQTT message %s#%d; sending ACK", mqttMsg.Topic(), mqttMsg.MessageID())
ack = true
}
}
// Returns the handler for a message sent to a given topic, supporting wildcards and other special syntaxes.
func (m *mqttPubSub) handlerForTopic(topic string) pubsub.Handler {
m.subscribingLock.RLock()
defer m.subscribingLock.RUnlock()
// First, try to see if we have a handler for the exact topic (no wildcards etc)
topicHandler, ok := m.topics[topic]
if ok && topicHandler.handler != nil {
return topicHandler.handler
}
// Iterate through the topics and run the matchers
for _, obj := range m.topics {
if obj.alias == topic {
return obj.handler
}
if obj.matcher != nil && obj.matcher(topic) {
return obj.handler
}
}
return nil
}
func (m *mqttPubSub) connect(ctx context.Context, clientID string) (mqtt.Client, error) {
uri, err := url.Parse(m.metadata.url)
if err != nil {
@ -371,3 +412,74 @@ func (m *mqttPubSub) Close() error {
func (m *mqttPubSub) Features() []pubsub.Feature {
return nil
}
var sharedSubscriptionMatch = regexp.MustCompile(`^\$share\/(.*?)\/.`)
// Adds a topic to the list of subscriptions.
func (m *mqttPubSub) addTopic(origTopicName string, handler pubsub.Handler) {
obj := mqttPubSubSubscription{
handler: handler,
}
// Shared subscriptions begin with "$share/GROUPID/" and we can remove that prefix
topicName := origTopicName
if found := sharedSubscriptionMatch.FindStringIndex(origTopicName); found != nil && found[0] == 0 {
topicName = topicName[(found[1] - 1):]
obj.alias = topicName
}
// If the topic name contains a wildcard, we need to add a matcher
regexStr := buildRegexForTopic(topicName)
if regexStr != "" {
// We built our own regex and this should never panic
match := regexp.MustCompile(regexStr)
obj.matcher = func(topic string) bool {
return match.MatchString(topic)
}
}
m.topics[origTopicName] = obj
}
// Returns a regular expression string that matches the topic, with support for wildcards.
func buildRegexForTopic(topicName string) string {
// This is a bit more lax than the specs, which for example require "#" to be at the end of the string only:
// in practice, seems that (at least some) brokers are more flexible and allow "#" in the middle of a string too
var (
regexStr string
lastPos int = -1
okPos bool
)
if strings.ContainsAny(topicName, "#+") {
regexStr = "^"
// It's ok to iterate over bytes here (rather than codepoints) because all characters we're looking for are always single-byte
for i := 0; i < len(topicName); i++ {
// Wildcard chars must either be at the beginning of the string or must follow a /
okPos = (i == 0 || topicName[i-1] == '/')
if topicName[i] == '#' && okPos {
lastPos = i
if i > 0 && i == (len(topicName)-1) {
// Edge case: we're at the end of the string so we can allow omitting the preceding /
regexStr += regexp.QuoteMeta(topicName[0:(i-1)]) + "(.*)"
} else {
regexStr += regexp.QuoteMeta(topicName[0:i]) + "(.*)"
}
} else if topicName[i] == '+' && okPos {
lastPos = i
if i > 0 && i == (len(topicName)-1) {
// Edge case: we're at the end of the string so we can allow omitting the preceding /
regexStr += regexp.QuoteMeta(topicName[0:(i-1)]) + `((\/|)[^\/]*)`
} else {
regexStr += regexp.QuoteMeta(topicName[0:i]) + `([^\/]*)`
}
}
}
regexStr += regexp.QuoteMeta(topicName[(lastPos+1):]) + "$"
}
if lastPos == -1 {
return ""
}
return regexStr
}

View File

@ -17,6 +17,7 @@ import (
"crypto/x509"
"encoding/pem"
"errors"
"regexp"
"testing"
"github.com/stretchr/testify/assert"
@ -185,3 +186,196 @@ func TestParseMetadata(t *testing.T) {
assert.NotNil(t, m.tlsCfg.clientKey, "failed to parse valid client certificate key")
})
}
func Test_buildRegexForTopic(t *testing.T) {
type args struct {
topicName string
}
tests := []struct {
name string
args args
regex string
tryMatches map[string]bool
}{
{
name: "no wildcard",
args: args{topicName: "hello world"},
regex: "",
},
{
name: "#",
args: args{topicName: "#"},
regex: "^(.*)$",
tryMatches: map[string]bool{
"helloworld": true,
"helloworld/": true,
"helloworld/22": true,
"/helloworld": true,
"/helloworld/": true,
"/helloworld/22": true,
"Ei fu. Siccome immobile, dato il mortal sospiro, stette la spoglia immemore.": true,
"🐶": true,
"🐶/foo": true,
"🐶/foo/bar": true,
},
},
{
// This should be forbidden by the specs, but apparently it works in brokers
name: "#/foo",
args: args{topicName: "#/foo"},
regex: "^(.*)/foo$",
tryMatches: map[string]bool{
"helloworld": false,
"helloworld/": false,
"helloworld/22": false,
"helloworld/foo": true,
"hello/world/foo": true,
"helloworld/foo/bar": false,
"/helloworld": false,
"/helloworld/": false,
"/helloworld/22": false,
"/helloworld/foo": true,
"/hello/world/foo": true,
"/helloworld/foo/bar": false,
"🐶": false,
"🐶/foo": true,
"🐶/😄/foo": true,
"🐶/foo/bar": false,
"🐶/😄": false,
},
},
{
name: "+",
args: args{topicName: "+"},
regex: `^([^\/]*)$`,
tryMatches: map[string]bool{
"helloworld": true,
"helloworld/": false,
"helloworld/22": false,
"/helloworld": false,
"/helloworld/": false,
"/helloworld/22": false,
"Ei fu. Siccome immobile, dato il mortal sospiro, stette la spoglia immemore.": true,
"🐶": true,
"🐶/foo": false,
"🐶/foo/bar": false,
},
},
{
name: "+/foo",
args: args{topicName: "+/foo"},
regex: `^([^\/]*)/foo$`,
tryMatches: map[string]bool{
"helloworld": false,
"helloworld/": false,
"helloworld/22": false,
"helloworld/foo": true,
"hello/world/foo": false,
"helloworld/foo/bar": false,
"/helloworld": false,
"/helloworld/": false,
"/helloworld/22": false,
"/helloworld/foo": false,
"/hello/world/foo": false,
"/helloworld/foo/bar": false,
"🐶": false,
"🐶/foo": true,
"🐶/😄/foo": false,
"🐶/foo/bar": false,
"🐶/😄": false,
},
},
{
name: "foo# (invalid)",
args: args{topicName: "foo#"},
regex: "",
},
{
name: "foo+ (invalid)",
args: args{topicName: "foo+"},
regex: "",
},
{
name: "foo/#",
args: args{topicName: "foo/#"},
regex: "^foo(.*)$",
tryMatches: map[string]bool{
"helloworld": false,
"foo": true,
"foo/": true,
"foo/bar": true,
"/helloworld": false,
"foo/helloworld": true,
"foo/hello/world": true,
"hello/world": false,
"🐶": false,
"foo/🐶": true,
"🐶/foo/bar": false,
"foo/🐶/bar": true,
},
},
{
// This should be forbidden by the specs, but apparently it works in brokers
name: "foo/#/bar",
args: args{topicName: "foo/#/bar"},
regex: "^foo/(.*)/bar$",
tryMatches: map[string]bool{
"helloworld": false,
"foo/": false,
"foo/bar": false,
"foo/hi/bar": true,
"foo/hi/hi/hi/bar": true,
"foo/hi/world": false,
},
},
{
name: "foo/+",
args: args{topicName: "foo/+"},
regex: `^foo((\/|)[^\/]*)$`,
tryMatches: map[string]bool{
"helloworld": false,
"foo": true,
"foo/": true,
"foo/bar": true,
"/helloworld": false,
"foo/helloworld": true,
"foo/hello/world": false,
"hello/world": false,
"🐶": false,
"foo/🐶": true,
"🐶/foo/bar": false,
"foo/🐶/bar": false,
},
},
{
name: "foo/+/bar",
args: args{topicName: "foo/+/bar"},
regex: `^foo/([^\/]*)/bar$`,
tryMatches: map[string]bool{
"helloworld": false,
"foo/": false,
"foo/bar": false,
"foo/hi/bar": true,
"foo/hi/hi/hi/bar": false,
"foo/hi/world": false,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := buildRegexForTopic(tt.args.topicName)
if got != tt.regex {
t.Errorf("buildRegexForTopic(%v) = %v, want %v", tt.args.topicName, got, tt.regex)
return
}
if len(tt.tryMatches) > 0 {
re := regexp.MustCompile(got)
for topic, match := range tt.tryMatches {
if matched := re.MatchString(topic); matched != match {
t.Errorf("buildRegexForTopic(%v) - match(%v) returned %v but expected %v", tt.args.topicName, topic, matched, match)
}
}
}
})
}
}

View File

@ -11,8 +11,8 @@ spec:
- name: consumerID
value: "testConsumer1"
- name: retain
value: true
value: false
- name: qos
value: 2
value: 1
- name: cleanSession
value: false

View File

@ -0,0 +1,10 @@
apiVersion: dapr.io/v1alpha1
kind: Component
metadata:
name: messagebus
spec:
type: pubsub.mqtt
version: v1
metadata:
- name: url
value: "tcp://localhost:1884"

View File

@ -15,13 +15,18 @@ package mqtt_test
import (
"context"
"crypto/rand"
"encoding/hex"
"fmt"
"io"
"strings"
"sync"
"testing"
"time"
"github.com/cenkalti/backoff"
mqtt "github.com/eclipse/paho.mqtt.golang"
// "github.com/google/uuid"
"github.com/stretchr/testify/require"
"go.uber.org/multierr"
@ -54,8 +59,10 @@ import (
const (
sidecarName1 = "dapr-1"
sidecarName2 = "dapr-2"
sidecarName3 = "dapr-3"
appID1 = "app-1"
appID2 = "app-2"
appID3 = "app-3"
clusterName = "mqttcertification"
dockerComposeYAML = "docker-compose.yml"
numMessages = 1000
@ -64,8 +71,12 @@ const (
messageKey = "partitionKey"
mqttURL = "tcp://localhost:1884"
pubsubName = "messagebus"
topicName = "neworder"
pubsubName = "messagebus"
topicName = "neworder"
wildcardTopicSubscribe = "orders/#"
wildcardTopicPublish = "orders/%s"
sharedTopicSubscribe = "$share/mygroup/mytopic/+/hello"
sharedTopicPublish = "mytopic/%s/hello"
)
var brokers = []string{"localhost:1884"}
@ -92,6 +103,11 @@ func mqttReady(url string) flow.Runnable {
func TestMQTT(t *testing.T) {
log := logger.NewLogger("dapr.components")
logger.ApplyOptionsToLoggers(&logger.Options{
OutputLevel: "debug",
})
component := pubsub_loader.New("mqtt", func() pubsub.PubSub {
return pubsub_mqtt.NewMQTTPubSub(log)
})
@ -99,9 +115,11 @@ func TestMQTT(t *testing.T) {
// In-order processing not guaranteed
consumerGroup1 := watcher.NewUnordered()
consumerGroup2 := watcher.NewUnordered()
consumerGroupMultiWildcard := watcher.NewUnordered()
consumerGroupMultiShared := watcher.NewUnordered()
// Application logic that tracks messages from a topic.
application := func(messages *watcher.Watcher, appID string) app.SetupFn {
application := func(messages *watcher.Watcher, appID string, topicName string) app.SetupFn {
return func(ctx flow.Context, s common.Service) error {
// Simulate periodic errors.
sim := simulate.PeriodicError(ctx, 100)
@ -126,9 +144,37 @@ func TestMQTT(t *testing.T) {
}
}
// Application logic that subscribes to multiple topics
applicationMultiTopic := func(appID string, subs ...topicSubscription) app.SetupFn {
return func(ctx flow.Context, s common.Service) (err error) {
handlerGen := func(name string, messages *watcher.Watcher) func(_ context.Context, e *common.TopicEvent) (retry bool, err error) {
return func(_ context.Context, e *common.TopicEvent) (retry bool, err error) {
messages.Observe(e.Data)
ctx.Logf("%s/%s Event - pubsub: %s, topic: %s, id: %s, data: %s", appID, name,
e.PubsubName, e.Topic, e.ID, e.Data)
return false, nil
}
}
for _, sub := range subs {
err = s.AddTopicEventHandler(
&common.Subscription{
PubsubName: pubsubName,
Topic: sub.name,
Route: sub.route,
}, handlerGen(sub.name, sub.messages),
)
if err != nil {
return err
}
}
return nil
}
}
// Test logic that sends messages to a topic and
// verifies the application has received them.
test := func(messages ...*watcher.Watcher) flow.Runnable {
test := func(topicName string, messages ...*watcher.Watcher) flow.Runnable {
return func(ctx flow.Context) error {
client := sidecar.GetClient(ctx, sidecarName1)
@ -136,7 +182,7 @@ func TestMQTT(t *testing.T) {
// that will satisfy the test.
msgs := make([]string, numMessages)
for i := range msgs {
msgs[i] = fmt.Sprintf("Hello, Messages %03d", i)
msgs[i] = fmt.Sprintf("Hello, Messages %s#%03d", topicName, i)
}
for _, m := range messages {
m.ExpectStrings(msgs...)
@ -145,9 +191,13 @@ func TestMQTT(t *testing.T) {
// Send events that the application above will observe.
ctx.Log("Sending messages!")
for _, msg := range msgs {
ctx.Logf("Sending: %q", msg)
err := client.PublishEvent(
ctx, pubsubName, topicName, msg)
// If topicName has a %s, this will add some randomness (if not, it won't be changed)
tn := topicName
if strings.Contains(tn, "%s") {
tn = fmt.Sprintf(tn, randomStr())
}
ctx.Logf("Sending '%q' to topic '%s'", msg, tn)
err := client.PublishEvent(ctx, pubsubName, tn, msg)
require.NoError(ctx, err, "error publishing message")
}
@ -160,11 +210,11 @@ func TestMQTT(t *testing.T) {
}
}
multiple_test := func(messages ...*watcher.Watcher) flow.Runnable {
multipleTest := func(messages ...*watcher.Watcher) flow.Runnable {
return func(ctx flow.Context) error {
var wg sync.WaitGroup
wg.Add(2)
publish_msgs := func(sidecarName string) {
publishMsgs := func(sidecarName string) {
defer wg.Done()
client := sidecar.GetClient(ctx, sidecarName)
msgs := make([]string, numMessages/2)
@ -182,8 +232,8 @@ func TestMQTT(t *testing.T) {
require.NoError(ctx, err, "error publishing message")
}
}
go publish_msgs(sidecarName1)
go publish_msgs(sidecarName2)
go publishMsgs(sidecarName1)
go publishMsgs(sidecarName2)
wg.Wait()
// Do the messages we observed match what we expect?
@ -194,6 +244,7 @@ func TestMQTT(t *testing.T) {
return nil
}
}
// sendMessagesInBackground and assertMessages are
// Runnables for testing publishing and consuming
// messages reliably when infrastructure and network
@ -267,7 +318,7 @@ func TestMQTT(t *testing.T) {
//
// Run the application logic above(App1)
Step(app.Run(appID1, fmt.Sprintf(":%d", appPort),
application(consumerGroup1, appID1))).
application(consumerGroup1, appID1, topicName))).
// Run the Dapr sidecar with the MQTTPubSub component.
Step(sidecar.Run(sidecarName1,
embedded.WithComponentsPath("./components/consumer1"),
@ -277,12 +328,12 @@ func TestMQTT(t *testing.T) {
runtime.WithPubSubs(component))).
//
// Send messages and test
Step("send and wait", test(consumerGroup1)).
Step("send and wait", test(topicName, consumerGroup1)).
Step("reset", flow.Reset(consumerGroup1)).
//
//Run Second application App2
Step(app.Run(appID2, fmt.Sprintf(":%d", appPort+portOffset),
application(consumerGroup2, appID2))).
application(consumerGroup2, appID2, topicName))).
// Run the Dapr sidecar with the MQTTPubSub component.
Step(sidecar.Run(sidecarName2,
embedded.WithComponentsPath("./components/consumer2"),
@ -293,9 +344,31 @@ func TestMQTT(t *testing.T) {
runtime.WithPubSubs(component))).
//
// Send messages and test
Step("multiple send and wait", multiple_test(consumerGroup1, consumerGroup2)).
Step("multiple send and wait", multipleTest(consumerGroup1, consumerGroup2)).
Step("reset", flow.Reset(consumerGroup1, consumerGroup2)).
//
// Test multiple topics and wildcards
Step(
app.Run(
appID3,
fmt.Sprintf(":%d", appPort+(portOffset*3)),
applicationMultiTopic(
appID3,
topicSubscription{messages: consumerGroupMultiWildcard, name: wildcardTopicSubscribe, route: "/wildcard"},
topicSubscription{messages: consumerGroupMultiShared, name: sharedTopicSubscribe, route: "/shared"},
),
),
).
Step(sidecar.Run(sidecarName3,
embedded.WithComponentsPath("./components/consumer3"),
embedded.WithAppProtocol(runtime.HTTPProtocol, appPort+(portOffset*3)),
embedded.WithDaprGRPCPort(runtime.DefaultDaprAPIGRPCPort+(portOffset*3)),
embedded.WithDaprHTTPPort(runtime.DefaultDaprHTTPPort+(portOffset*3)),
embedded.WithProfilePort(runtime.DefaultProfilePort+(portOffset*3)),
runtime.WithPubSubs(component))).
Step("send and wait wildcard", test(wildcardTopicPublish, consumerGroupMultiWildcard)).
Step("send and wait shared", test(sharedTopicPublish, consumerGroupMultiShared)).
//
// Infra test
StepAsync("steady flow of messages to publish", &task,
sendMessagesInBackground(consumerGroup1, consumerGroup2)).
@ -338,3 +411,15 @@ func TestMQTT(t *testing.T) {
Step("assert messages", assertMessages(consumerGroup1, consumerGroup2)).
Run()
}
type topicSubscription struct {
messages *watcher.Watcher
name string
route string
}
func randomStr() string {
buf := make([]byte, 4)
_, _ = io.ReadFull(rand.Reader, buf)
return hex.EncodeToString(buf)
}