Lots of fixes

Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com>
This commit is contained in:
ItalyPaleAle 2023-01-26 00:26:00 +00:00
parent dbf8426993
commit 64087d7f4b
2 changed files with 70 additions and 56 deletions

View File

@ -62,20 +62,20 @@ type tlsCfg struct {
clientKey string
}
func parseMQTTMetaData(md bindings.Metadata, log logger.Logger) (*metadata, error) {
func parseMQTTMetaData(md bindings.Metadata, log logger.Logger) (metadata, error) {
m := metadata{}
// required configuration settings
if val, ok := md.Properties[mqttURL]; ok && val != "" {
m.url = val
} else {
return &m, errors.New("missing url")
return m, errors.New("missing url")
}
if val, ok := md.Properties[mqttTopic]; ok && val != "" {
m.topic = val
} else {
return &m, errors.New("missing topic")
return m, errors.New("missing topic")
}
// optional configuration settings
@ -87,7 +87,7 @@ func parseMQTTMetaData(md bindings.Metadata, log logger.Logger) (*metadata, erro
if val, ok := md.Properties[mqttClientID]; ok && val != "" {
m.clientID = val
} else {
return &m, errors.New("missing consumerID")
return m, errors.New("missing consumerID")
}
m.cleanSession = defaultCleanSession
@ -97,19 +97,19 @@ func parseMQTTMetaData(md bindings.Metadata, log logger.Logger) (*metadata, erro
if val, ok := md.Properties[mqttCACert]; ok && val != "" {
if !isValidPEM(val) {
return &m, errors.New("invalid ca certificate")
return m, errors.New("invalid ca certificate")
}
m.tlsCfg.caCert = val
}
if val, ok := md.Properties[mqttClientCert]; ok && val != "" {
if !isValidPEM(val) {
return &m, errors.New("invalid client certificate")
return m, errors.New("invalid client certificate")
}
m.tlsCfg.clientCert = val
}
if val, ok := md.Properties[mqttClientKey]; ok && val != "" {
if !isValidPEM(val) {
return &m, errors.New("invalid client certificate key")
return m, errors.New("invalid client certificate key")
}
m.tlsCfg.clientKey = val
}
@ -117,7 +117,7 @@ func parseMQTTMetaData(md bindings.Metadata, log logger.Logger) (*metadata, erro
if val, ok := md.Properties[mqttBackOffMaxRetries]; ok && val != "" {
backOffMaxRetriesInt, err := strconv.Atoi(val)
if err != nil {
return &m, fmt.Errorf("invalid backOffMaxRetries %s: %v", val, err)
return m, fmt.Errorf("invalid backOffMaxRetries %s: %v", val, err)
}
m.backOffMaxRetries = backOffMaxRetriesInt
}
@ -128,7 +128,7 @@ func parseMQTTMetaData(md bindings.Metadata, log logger.Logger) (*metadata, erro
log.Warn("Metadata property 'qos' has been deprecated and is ignored; qos is set to 1")
}
return &m, nil
return m, nil
}
// isValidPEM validates the provided input has PEM formatted block.

View File

@ -20,6 +20,7 @@ import (
"errors"
"fmt"
"net/url"
"sync"
"sync/atomic"
"time"
@ -34,7 +35,8 @@ import (
// MQTT allows sending and receiving data to/from an MQTT broker.
type MQTT struct {
producer mqtt.Client
metadata *metadata
producerLock sync.RWMutex
metadata metadata
logger logger.Logger
isSubscribed atomic.Bool
readHandler bindings.Handler
@ -51,28 +53,18 @@ func NewMQTT(logger logger.Logger) bindings.InputOutputBinding {
}
// Init does MQTT connection parsing.
func (m *MQTT) Init(metadata bindings.Metadata) error {
mqttMeta, err := parseMQTTMetaData(metadata, m.logger)
func (m *MQTT) Init(metadata bindings.Metadata) (err error) {
m.metadata, err = parseMQTTMetaData(metadata, m.logger)
if err != nil {
return err
}
m.metadata = mqttMeta
m.ctx, m.cancel = context.WithCancel(context.Background())
// mqtt broker allows only one connection at a given time from a clientID.
producerClientID := fmt.Sprintf("%s-producer", m.metadata.clientID)
p, err := m.connect(producerClientID, false)
if err != nil {
return err
}
// TODO: Make the backoff configurable for constant or exponential
b := backoff.NewConstantBackOff(5 * time.Second)
m.backOff = backoff.WithContext(b, m.ctx)
m.producer = p
return nil
}
@ -82,7 +74,42 @@ func (m *MQTT) Operations() []bindings.OperationKind {
}
}
func (m *MQTT) getProducer() (mqtt.Client, error) {
// Get the producer from the cache
m.producerLock.RLock()
producer := m.producer
m.producerLock.RUnlock()
if producer != nil {
return producer, nil
}
// Must create a new producer
m.producerLock.Lock()
defer m.producerLock.Unlock()
// Check again in case another goroutine created it in the meanwhile
producer = m.producer
if producer != nil {
return producer, nil
}
// mqtt broker allows only one connection at a given time from a clientID.
producerClientID := fmt.Sprintf("%s-producer", m.metadata.clientID)
p, err := m.connect(producerClientID, false)
if err != nil {
return nil, err
}
m.producer = p
return p, nil
}
func (m *MQTT) Invoke(parentCtx context.Context, req *bindings.InvokeRequest) (*bindings.InvokeResponse, error) {
producer, err := m.getProducer()
if err != nil {
return nil, fmt.Errorf("failed to create producer connection: %w", err)
}
// MQTT client Publish() has an internal race condition in the default autoreconnect config.
// To mitigate sporadic failures on the Dapr side, this implementation retries 3 times at
// a fixed 200ms interval. This is not configurable to keep this as an implementation detail
@ -99,7 +126,7 @@ func (m *MQTT) Invoke(parentCtx context.Context, req *bindings.InvokeRequest) (*
topic = m.metadata.topic
}
return nil, retry.NotifyRecover(func() (err error) {
token := m.producer.Publish(topic, m.metadata.qos, m.metadata.retain, req.Data)
token := producer.Publish(topic, m.metadata.qos, m.metadata.retain, req.Data)
ctx, cancel := context.WithTimeout(parentCtx, defaultWait)
defer cancel()
select {
@ -123,37 +150,6 @@ func (m *MQTT) Invoke(parentCtx context.Context, req *bindings.InvokeRequest) (*
})
}
func (m *MQTT) handleMessage(ctx context.Context, handler bindings.Handler, mqttMsg mqtt.Message) error {
msg := bindings.ReadResponse{
Data: mqttMsg.Payload(),
Metadata: map[string]string{
mqttTopic: mqttMsg.Topic(),
},
}
// paho.mqtt.golang requires that handlers never block or it can deadlock on client.Disconnect.
// To ensure that the Dapr runtime does not hang on teardown on of the component, run the app's
// handling code in a goroutine so that this handler function is always cancellable on Close().
ch := make(chan error)
go func(m *bindings.ReadResponse) {
defer close(ch)
_, err := handler(ctx, m)
ch <- err
}(&msg)
select {
case handlerErr := <-ch:
if handlerErr != nil {
return handlerErr
}
mqttMsg.Ack()
return nil
case <-m.ctx.Done():
m.logger.Infof("Read context cancelled: %v", m.ctx.Err())
return m.ctx.Err()
}
}
func (m *MQTT) Read(ctx context.Context, handler bindings.Handler) error {
// If the subscription is already active, wait 2s before retrying (in case we're still disconnecting), otherwise return an error
if !m.isSubscribed.CompareAndSwap(false, true) {
@ -313,7 +309,19 @@ func (m *MQTT) createSubscriberClientOptions(uri *url.URL, clientID string) *mqt
err := retry.NotifyRecover(
func() error {
m.logger.Debugf("Processing MQTT message %s/%d", mqttMsg.Topic(), mqttMsg.MessageID())
return m.handleMessage(ctx, m.readHandler, mqttMsg)
_, err := m.readHandler(ctx, &bindings.ReadResponse{
Data: mqttMsg.Payload(),
Metadata: map[string]string{
mqttTopic: mqttMsg.Topic(),
},
})
if err != nil {
return err
}
// Ack the message on success
mqttMsg.Ack()
return nil
},
bo,
func(err error, d time.Duration) {
@ -350,10 +358,16 @@ func (m *MQTT) createSubscriberClientOptions(uri *url.URL, clientID string) *mqt
}
func (m *MQTT) Close() error {
m.producerLock.Lock()
defer m.producerLock.Unlock()
// Canceling the context also causes Read to stop receiving messages
m.cancel()
m.producer.Disconnect(200)
if m.producer != nil {
m.producer.Disconnect(200)
m.producer = nil
}
return nil
}