components-contrib/pubsub/mqtt/mqtt.go

331 lines
8.7 KiB
Go

/*
Copyright 2021 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package mqtt
import (
"context"
"crypto/tls"
"crypto/x509"
"encoding/pem"
"fmt"
"net/url"
"strconv"
"time"
"github.com/cenkalti/backoff/v4"
mqtt "github.com/eclipse/paho.mqtt.golang"
"github.com/dapr/components-contrib/pubsub"
"github.com/dapr/kit/logger"
"github.com/dapr/kit/retry"
)
const (
// Keys.
mqttURL = "url"
mqttQOS = "qos"
mqttRetain = "retain"
mqttClientID = "consumerID"
mqttCleanSession = "cleanSession"
mqttCACert = "caCert"
mqttClientCert = "clientCert"
mqttClientKey = "clientKey"
mqttBackOffMaxRetries = "backOffMaxRetries"
// errors.
errorMsgPrefix = "mqtt pub sub error:"
// Defaults.
defaultQOS = 0
defaultRetain = false
defaultWait = 3 * time.Second
defaultCleanSession = true
)
// mqttPubSub type allows sending and receiving data to/from MQTT broker.
type mqttPubSub struct {
producer mqtt.Client
consumer mqtt.Client
metadata *metadata
logger logger.Logger
topics map[string]byte
ctx context.Context
cancel context.CancelFunc
backOff backoff.BackOff
}
// NewMQTTPubSub returns a new mqttPubSub instance.
func NewMQTTPubSub(logger logger.Logger) pubsub.PubSub {
return &mqttPubSub{logger: logger}
}
// isValidPEM validates the provided input has PEM formatted block.
func isValidPEM(val string) bool {
block, _ := pem.Decode([]byte(val))
return block != nil
}
func parseMQTTMetaData(md pubsub.Metadata) (*metadata, error) {
m := metadata{}
// required configuration settings
if val, ok := md.Properties[mqttURL]; ok && val != "" {
m.url = val
} else {
return &m, fmt.Errorf("%s missing url", errorMsgPrefix)
}
// optional configuration settings
m.qos = defaultQOS
if val, ok := md.Properties[mqttQOS]; ok && val != "" {
qosInt, err := strconv.Atoi(val)
if err != nil {
return &m, fmt.Errorf("%s invalid qos %s, %s", errorMsgPrefix, val, err)
}
m.qos = byte(qosInt)
}
m.retain = defaultRetain
if val, ok := md.Properties[mqttRetain]; ok && val != "" {
var err error
m.retain, err = strconv.ParseBool(val)
if err != nil {
return &m, fmt.Errorf("%s invalid retain %s, %s", errorMsgPrefix, val, err)
}
}
if val, ok := md.Properties[mqttClientID]; ok && val != "" {
m.clientID = val
} else {
return &m, fmt.Errorf("%s missing consumerID", errorMsgPrefix)
}
m.cleanSession = defaultCleanSession
if val, ok := md.Properties[mqttCleanSession]; ok && val != "" {
var err error
m.cleanSession, err = strconv.ParseBool(val)
if err != nil {
return &m, fmt.Errorf("%s invalid clean session %s, %s", errorMsgPrefix, val, err)
}
}
if val, ok := md.Properties[mqttCACert]; ok && val != "" {
if !isValidPEM(val) {
return &m, fmt.Errorf("%s invalid ca certificate", errorMsgPrefix)
}
m.tlsCfg.caCert = val
}
if val, ok := md.Properties[mqttClientCert]; ok && val != "" {
if !isValidPEM(val) {
return &m, fmt.Errorf("%s invalid client certificate", errorMsgPrefix)
}
m.tlsCfg.clientCert = val
}
if val, ok := md.Properties[mqttClientKey]; ok && val != "" {
if !isValidPEM(val) {
return &m, fmt.Errorf("%s invalid client certificate key", errorMsgPrefix)
}
m.tlsCfg.clientKey = val
}
if val, ok := md.Properties[mqttBackOffMaxRetries]; ok && val != "" {
backOffMaxRetriesInt, err := strconv.Atoi(val)
if err != nil {
return &m, fmt.Errorf("%s invalid backOffMaxRetries %s, %s", errorMsgPrefix, val, err)
}
m.backOffMaxRetries = backOffMaxRetriesInt
}
return &m, nil
}
// Init parses metadata and creates a new Pub Sub client.
func (m *mqttPubSub) Init(metadata pubsub.Metadata) error {
mqttMeta, err := parseMQTTMetaData(metadata)
if err != nil {
return err
}
m.metadata = mqttMeta
// mqtt broker allows only one connection at a given time from a clientID.
producerClientID := fmt.Sprintf("%s-producer", m.metadata.clientID)
p, err := m.connect(producerClientID)
if err != nil {
return err
}
m.ctx, m.cancel = context.WithCancel(context.Background())
// TODO: Make the backoff configurable for constant or exponential
b := backoff.NewConstantBackOff(5 * time.Second)
m.backOff = backoff.WithContext(b, m.ctx)
m.producer = p
m.topics = make(map[string]byte)
m.logger.Debug("mqtt message bus initialization complete")
return nil
}
// Publish the topic to mqtt pub sub.
func (m *mqttPubSub) Publish(req *pubsub.PublishRequest) error {
m.logger.Debugf("mqtt publishing topic %s with data: %v", req.Topic, req.Data)
token := m.producer.Publish(req.Topic, m.metadata.qos, m.metadata.retain, req.Data)
if !token.WaitTimeout(defaultWait) || token.Error() != nil {
return fmt.Errorf("mqtt error from publish: %v", token.Error())
}
return nil
}
// Subscribe to the mqtt pub sub topic.
func (m *mqttPubSub) Subscribe(req pubsub.SubscribeRequest, handler pubsub.Handler) error {
m.topics[req.Topic] = m.metadata.qos
// reset synchronization
if m.consumer != nil {
m.logger.Warnf("re-initializing the subscriber")
m.consumer.Disconnect(0)
m.consumer = nil
}
// mqtt broker allows only one connection at a given time from a clientID.
consumerClientID := fmt.Sprintf("%s-consumer", m.metadata.clientID)
c, err := m.connect(consumerClientID)
if err != nil {
return err
}
m.consumer = c
go func() {
token := m.consumer.SubscribeMultiple(
m.topics,
func(client mqtt.Client, mqttMsg mqtt.Message) {
mqttMsg.AutoAckOff()
msg := pubsub.NewMessage{
Topic: mqttMsg.Topic(),
Data: mqttMsg.Payload(),
}
b := m.backOff
if m.metadata.backOffMaxRetries >= 0 {
b = backoff.WithMaxRetries(m.backOff, uint64(m.metadata.backOffMaxRetries))
}
if err := retry.NotifyRecover(func() error {
m.logger.Debugf("Processing MQTT message %s/%d", mqttMsg.Topic(), mqttMsg.MessageID())
if err := handler(m.ctx, &msg); err != nil {
return err
}
mqttMsg.Ack()
return nil
}, b, func(err error, d time.Duration) {
m.logger.Errorf("Error processing MQTT message: %s/%d. Retrying...", mqttMsg.Topic(), mqttMsg.MessageID())
}, func() {
m.logger.Infof("Successfully processed MQTT message after it previously failed: %s/%d", mqttMsg.Topic(), mqttMsg.MessageID())
}); err != nil {
m.logger.Errorf("Failed processing MQTT message: %s/%d: %v", mqttMsg.Topic(), mqttMsg.MessageID(), err)
}
},
)
if err := token.Error(); err != nil {
m.logger.Errorf("mqtt error from subscribe: %v", err)
}
}()
return nil
}
func (m *mqttPubSub) connect(clientID string) (mqtt.Client, error) {
uri, err := url.Parse(m.metadata.url)
if err != nil {
return nil, err
}
opts := m.createClientOptions(uri, clientID)
client := mqtt.NewClient(opts)
token := client.Connect()
for !token.WaitTimeout(defaultWait) {
}
if err := token.Error(); err != nil {
return nil, err
}
return client, nil
}
func (m *mqttPubSub) newTLSConfig() *tls.Config {
tlsConfig := new(tls.Config)
if m.metadata.clientCert != "" && m.metadata.clientKey != "" {
cert, err := tls.X509KeyPair([]byte(m.metadata.clientCert), []byte(m.metadata.clientKey))
if err != nil {
m.logger.Warnf("unable to load client certificate and key pair. Err: %v", err)
return tlsConfig
}
tlsConfig.Certificates = []tls.Certificate{cert}
}
if m.metadata.caCert != "" {
tlsConfig.RootCAs = x509.NewCertPool()
if ok := tlsConfig.RootCAs.AppendCertsFromPEM([]byte(m.metadata.caCert)); !ok {
m.logger.Warnf("unable to load ca certificate.")
}
}
return tlsConfig
}
func (m *mqttPubSub) createClientOptions(uri *url.URL, clientID string) *mqtt.ClientOptions {
opts := mqtt.NewClientOptions()
opts.SetClientID(clientID)
opts.SetCleanSession(m.metadata.cleanSession)
// URL scheme backward compatibility
scheme := uri.Scheme
switch scheme {
case "mqtt":
scheme = "tcp"
case "mqtts", "tcps", "tls":
scheme = "ssl"
}
opts.AddBroker(scheme + "://" + uri.Host)
opts.SetUsername(uri.User.Username())
password, _ := uri.User.Password()
opts.SetPassword(password)
// tls config
opts.SetTLSConfig(m.newTLSConfig())
return opts
}
func (m *mqttPubSub) Close() error {
m.cancel()
if m.consumer != nil {
m.consumer.Disconnect(0)
}
m.producer.Disconnect(0)
return nil
}
func (m *mqttPubSub) Features() []pubsub.Feature {
return nil
}