374 lines
11 KiB
Go
374 lines
11 KiB
Go
/*
|
|
Copyright 2021 The Dapr Authors
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
*/
|
|
|
|
package mqtt
|
|
|
|
import (
|
|
"context"
|
|
"crypto/tls"
|
|
"crypto/x509"
|
|
"encoding/pem"
|
|
"fmt"
|
|
"net/url"
|
|
"os"
|
|
"os/signal"
|
|
"strconv"
|
|
"syscall"
|
|
"time"
|
|
|
|
"github.com/cenkalti/backoff/v4"
|
|
mqtt "github.com/eclipse/paho.mqtt.golang"
|
|
|
|
"github.com/dapr/components-contrib/bindings"
|
|
"github.com/dapr/kit/logger"
|
|
"github.com/dapr/kit/retry"
|
|
)
|
|
|
|
// Metadata property keys understood by the MQTT binding.
const (
	mqttURL               = "url"
	mqttTopic             = "topic"
	mqttQOS               = "qos"
	mqttRetain            = "retain"
	mqttClientID          = "consumerID"
	mqttCleanSession      = "cleanSession"
	mqttCACert            = "caCert"
	mqttClientCert        = "clientCert"
	mqttClientKey         = "clientKey"
	mqttBackOffMaxRetries = "backOffMaxRetries"
)

// errorMsgPrefix prefixes the error messages produced while parsing metadata.
const errorMsgPrefix = "mqtt binding error:"

// Defaults applied when optional metadata is absent, plus the interval used
// when waiting on paho tokens.
const (
	defaultQOS          = 0
	defaultRetain       = false
	defaultWait         = 3 * time.Second
	defaultCleanSession = true
)
|
|
|
|
// MQTT allows sending and receiving data to/from an MQTT broker.
|
|
type MQTT struct {
|
|
producer mqtt.Client
|
|
consumer mqtt.Client
|
|
metadata *metadata
|
|
logger logger.Logger
|
|
|
|
ctx context.Context
|
|
cancel context.CancelFunc
|
|
backOff backoff.BackOff
|
|
}
|
|
|
|
// NewMQTT returns a new MQTT instance.
|
|
func NewMQTT(logger logger.Logger) *MQTT {
|
|
return &MQTT{logger: logger}
|
|
}
|
|
|
|
// isValidPEM reports whether val contains at least one PEM-encoded block.
// Only the PEM framing is checked; the block's type and contents (e.g. whether
// it really is a certificate or key) are not inspected.
func isValidPEM(val string) bool {
	if block, _ := pem.Decode([]byte(val)); block == nil {
		return false
	}
	return true
}
|
|
|
|
func parseMQTTMetaData(md bindings.Metadata) (*metadata, error) {
|
|
m := metadata{}
|
|
|
|
// required configuration settings
|
|
if val, ok := md.Properties[mqttURL]; ok && val != "" {
|
|
m.url = val
|
|
} else {
|
|
return &m, fmt.Errorf("%s missing url", errorMsgPrefix)
|
|
}
|
|
|
|
if val, ok := md.Properties[mqttTopic]; ok && val != "" {
|
|
m.topic = val
|
|
} else {
|
|
return &m, fmt.Errorf("%s missing topic", errorMsgPrefix)
|
|
}
|
|
|
|
// optional configuration settings
|
|
m.qos = defaultQOS
|
|
if val, ok := md.Properties[mqttQOS]; ok && val != "" {
|
|
qosInt, err := strconv.Atoi(val)
|
|
if err != nil {
|
|
return &m, fmt.Errorf("%s invalid qos %s, %s", errorMsgPrefix, val, err)
|
|
}
|
|
m.qos = byte(qosInt)
|
|
}
|
|
|
|
m.retain = defaultRetain
|
|
if val, ok := md.Properties[mqttRetain]; ok && val != "" {
|
|
var err error
|
|
m.retain, err = strconv.ParseBool(val)
|
|
if err != nil {
|
|
return &m, fmt.Errorf("%s invalid retain %s, %s", errorMsgPrefix, val, err)
|
|
}
|
|
}
|
|
|
|
if val, ok := md.Properties[mqttClientID]; ok && val != "" {
|
|
m.clientID = val
|
|
} else {
|
|
return &m, fmt.Errorf("%s missing consumerID", errorMsgPrefix)
|
|
}
|
|
|
|
m.cleanSession = defaultCleanSession
|
|
if val, ok := md.Properties[mqttCleanSession]; ok && val != "" {
|
|
var err error
|
|
m.cleanSession, err = strconv.ParseBool(val)
|
|
if err != nil {
|
|
return &m, fmt.Errorf("%s invalid clean session %s, %s", errorMsgPrefix, val, err)
|
|
}
|
|
}
|
|
|
|
if val, ok := md.Properties[mqttCACert]; ok && val != "" {
|
|
if !isValidPEM(val) {
|
|
return &m, fmt.Errorf("%s invalid ca certificate", errorMsgPrefix)
|
|
}
|
|
m.tlsCfg.caCert = val
|
|
}
|
|
if val, ok := md.Properties[mqttClientCert]; ok && val != "" {
|
|
if !isValidPEM(val) {
|
|
return &m, fmt.Errorf("%s invalid client certificate", errorMsgPrefix)
|
|
}
|
|
m.tlsCfg.clientCert = val
|
|
}
|
|
if val, ok := md.Properties[mqttClientKey]; ok && val != "" {
|
|
if !isValidPEM(val) {
|
|
return &m, fmt.Errorf("%s invalid client certificate key", errorMsgPrefix)
|
|
}
|
|
m.tlsCfg.clientKey = val
|
|
}
|
|
|
|
if val, ok := md.Properties[mqttBackOffMaxRetries]; ok && val != "" {
|
|
backOffMaxRetriesInt, err := strconv.Atoi(val)
|
|
if err != nil {
|
|
return &m, fmt.Errorf("%s invalid backOffMaxRetries %s, %s", errorMsgPrefix, val, err)
|
|
}
|
|
m.backOffMaxRetries = backOffMaxRetriesInt
|
|
}
|
|
|
|
return &m, nil
|
|
}
|
|
|
|
// Init does MQTT connection parsing.
|
|
func (m *MQTT) Init(metadata bindings.Metadata) error {
|
|
mqttMeta, err := parseMQTTMetaData(metadata)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
m.metadata = mqttMeta
|
|
|
|
// mqtt broker allows only one connection at a given time from a clientID.
|
|
producerClientID := fmt.Sprintf("%s-producer", m.metadata.clientID)
|
|
p, err := m.connect(producerClientID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
m.ctx, m.cancel = context.WithCancel(context.Background())
|
|
|
|
// TODO: Make the backoff configurable for constant or exponential
|
|
b := backoff.NewConstantBackOff(5 * time.Second)
|
|
m.backOff = backoff.WithContext(b, m.ctx)
|
|
|
|
m.producer = p
|
|
|
|
m.logger.Debug("mqtt message bus initialization complete")
|
|
|
|
return nil
|
|
}
|
|
|
|
func (m *MQTT) Operations() []bindings.OperationKind {
|
|
return []bindings.OperationKind{bindings.CreateOperation}
|
|
}
|
|
|
|
func (m *MQTT) Invoke(ctx context.Context, req *bindings.InvokeRequest) (*bindings.InvokeResponse, error) {
|
|
// MQTT client Publish() has an internal race condition in the default autoreconnect config.
|
|
// To mitigate sporadic failures on the Dapr side, this implementation retries 3 times at
|
|
// a fixed 200ms interval. This is not configurable to keep this as an implementation detail
|
|
// for this component, as the additional public config metadata required could be replaced
|
|
// by the more general Dapr APIs for resiliency moving forwards.
|
|
cbo := backoff.NewConstantBackOff(200 * time.Millisecond)
|
|
bo := backoff.WithMaxRetries(cbo, 3)
|
|
bo = backoff.WithContext(bo, ctx)
|
|
|
|
return nil, retry.NotifyRecover(func() error {
|
|
topic, ok := req.Metadata[mqttTopic]
|
|
if !ok || topic == "" {
|
|
// If user does not specify a topic, publish via the component's default topic.
|
|
topic = m.metadata.topic
|
|
}
|
|
m.logger.Debugf("mqtt publishing topic %s with data: %v", topic, req.Data)
|
|
token := m.producer.Publish(topic, m.metadata.qos, m.metadata.retain, req.Data)
|
|
if !token.WaitTimeout(defaultWait) || token.Error() != nil {
|
|
return fmt.Errorf("mqtt error from publish: %v", token.Error())
|
|
}
|
|
return nil
|
|
}, bo, func(err error, _ time.Duration) {
|
|
m.logger.Debugf("Could not publish MQTT message. Retrying...: %v", err)
|
|
}, func() {
|
|
m.logger.Debug("Successfully published MQTT message after it previously failed")
|
|
})
|
|
}
|
|
|
|
func (m *MQTT) handleMessage(handler bindings.Handler, mqttMsg mqtt.Message) error {
|
|
msg := bindings.ReadResponse{
|
|
Data: mqttMsg.Payload(),
|
|
Metadata: map[string]string{mqttTopic: mqttMsg.Topic()},
|
|
}
|
|
|
|
// paho.mqtt.golang requires that handlers never block or it can deadlock on client.Disconnect.
|
|
// To ensure that the Dapr runtime does not hang on teardown on of the component, run the app's
|
|
// handling code in a goroutine so that this handler function is always cancellable on Close().
|
|
ch := make(chan error)
|
|
go func(m *bindings.ReadResponse) {
|
|
defer close(ch)
|
|
_, err := handler(context.TODO(), m)
|
|
ch <- err
|
|
}(&msg)
|
|
|
|
select {
|
|
case handlerErr := <-ch:
|
|
if handlerErr != nil {
|
|
return handlerErr
|
|
}
|
|
mqttMsg.Ack()
|
|
return nil
|
|
case <-m.ctx.Done():
|
|
m.logger.Infof("Read context cancelled: %v", m.ctx.Err())
|
|
return m.ctx.Err()
|
|
}
|
|
}
|
|
|
|
func (m *MQTT) Read(handler bindings.Handler) error {
|
|
sigterm := make(chan os.Signal, 1)
|
|
signal.Notify(sigterm, os.Interrupt, syscall.SIGTERM)
|
|
|
|
// reset synchronization
|
|
if m.consumer != nil {
|
|
m.logger.Warnf("re-initializing the subscriber")
|
|
m.consumer.Disconnect(0)
|
|
m.consumer = nil
|
|
}
|
|
|
|
// mqtt broker allows only one connection at a given time from a clientID.
|
|
consumerClientID := fmt.Sprintf("%s-consumer", m.metadata.clientID)
|
|
c, err := m.connect(consumerClientID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
m.consumer = c
|
|
|
|
m.logger.Debugf("mqtt subscribing to topic %s", m.metadata.topic)
|
|
token := m.consumer.Subscribe(m.metadata.topic, m.metadata.qos, func(client mqtt.Client, mqttMsg mqtt.Message) {
|
|
b := m.backOff
|
|
if m.metadata.backOffMaxRetries >= 0 {
|
|
b = backoff.WithMaxRetries(m.backOff, uint64(m.metadata.backOffMaxRetries))
|
|
}
|
|
|
|
if err := retry.NotifyRecover(func() error {
|
|
m.logger.Debugf("Processing MQTT message %s/%d", mqttMsg.Topic(), mqttMsg.MessageID())
|
|
return m.handleMessage(handler, mqttMsg)
|
|
}, b, func(err error, d time.Duration) {
|
|
m.logger.Errorf("Error processing MQTT message: %s/%d. Retrying...", mqttMsg.Topic(), mqttMsg.MessageID())
|
|
}, func() {
|
|
m.logger.Infof("Successfully processed MQTT message after it previously failed: %s/%d", mqttMsg.Topic(), mqttMsg.MessageID())
|
|
}); err != nil {
|
|
m.logger.Errorf("Failed processing MQTT message: %s/%d: %v", mqttMsg.Topic(), mqttMsg.MessageID(), err)
|
|
}
|
|
})
|
|
if err := token.Error(); err != nil {
|
|
m.logger.Errorf("mqtt error from subscribe: %v", err)
|
|
|
|
return err
|
|
}
|
|
<-sigterm
|
|
|
|
return nil
|
|
}
|
|
|
|
func (m *MQTT) connect(clientID string) (mqtt.Client, error) {
|
|
uri, err := url.Parse(m.metadata.url)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
opts := m.createClientOptions(uri, clientID)
|
|
client := mqtt.NewClient(opts)
|
|
token := client.Connect()
|
|
for !token.WaitTimeout(defaultWait) {
|
|
}
|
|
if err := token.Error(); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return client, nil
|
|
}
|
|
|
|
func (m *MQTT) newTLSConfig() *tls.Config {
|
|
tlsConfig := new(tls.Config)
|
|
|
|
if m.metadata.clientCert != "" && m.metadata.clientKey != "" {
|
|
cert, err := tls.X509KeyPair([]byte(m.metadata.clientCert), []byte(m.metadata.clientKey))
|
|
if err != nil {
|
|
m.logger.Warnf("unable to load client certificate and key pair. Err: %v", err)
|
|
|
|
return tlsConfig
|
|
}
|
|
tlsConfig.Certificates = []tls.Certificate{cert}
|
|
}
|
|
|
|
if m.metadata.caCert != "" {
|
|
tlsConfig.RootCAs = x509.NewCertPool()
|
|
if ok := tlsConfig.RootCAs.AppendCertsFromPEM([]byte(m.metadata.caCert)); !ok {
|
|
m.logger.Warnf("unable to load ca certificate.")
|
|
}
|
|
}
|
|
|
|
return tlsConfig
|
|
}
|
|
|
|
func (m *MQTT) createClientOptions(uri *url.URL, clientID string) *mqtt.ClientOptions {
|
|
opts := mqtt.NewClientOptions()
|
|
opts.SetClientID(clientID)
|
|
opts.SetCleanSession(m.metadata.cleanSession)
|
|
// URL scheme backward compatibility
|
|
scheme := uri.Scheme
|
|
switch scheme {
|
|
case "mqtt":
|
|
scheme = "tcp"
|
|
case "mqtts", "tcps", "tls":
|
|
scheme = "ssl"
|
|
}
|
|
opts.AddBroker(scheme + "://" + uri.Host)
|
|
opts.SetUsername(uri.User.Username())
|
|
password, _ := uri.User.Password()
|
|
opts.SetPassword(password)
|
|
// tls config
|
|
opts.SetTLSConfig(m.newTLSConfig())
|
|
|
|
return opts
|
|
}
|
|
|
|
func (m *MQTT) Close() error {
|
|
// Cancel any read callback handlers before Disconnect to prevent deadlocks.
|
|
m.cancel()
|
|
|
|
if m.consumer != nil {
|
|
m.consumer.Disconnect(1)
|
|
}
|
|
m.producer.Disconnect(1)
|
|
|
|
return nil
|
|
}
|