Lots of fixes
Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com>
This commit is contained in:
parent
dbf8426993
commit
64087d7f4b
|
@ -62,20 +62,20 @@ type tlsCfg struct {
|
|||
clientKey string
|
||||
}
|
||||
|
||||
func parseMQTTMetaData(md bindings.Metadata, log logger.Logger) (*metadata, error) {
|
||||
func parseMQTTMetaData(md bindings.Metadata, log logger.Logger) (metadata, error) {
|
||||
m := metadata{}
|
||||
|
||||
// required configuration settings
|
||||
if val, ok := md.Properties[mqttURL]; ok && val != "" {
|
||||
m.url = val
|
||||
} else {
|
||||
return &m, errors.New("missing url")
|
||||
return m, errors.New("missing url")
|
||||
}
|
||||
|
||||
if val, ok := md.Properties[mqttTopic]; ok && val != "" {
|
||||
m.topic = val
|
||||
} else {
|
||||
return &m, errors.New("missing topic")
|
||||
return m, errors.New("missing topic")
|
||||
}
|
||||
|
||||
// optional configuration settings
|
||||
|
@ -87,7 +87,7 @@ func parseMQTTMetaData(md bindings.Metadata, log logger.Logger) (*metadata, erro
|
|||
if val, ok := md.Properties[mqttClientID]; ok && val != "" {
|
||||
m.clientID = val
|
||||
} else {
|
||||
return &m, errors.New("missing consumerID")
|
||||
return m, errors.New("missing consumerID")
|
||||
}
|
||||
|
||||
m.cleanSession = defaultCleanSession
|
||||
|
@ -97,19 +97,19 @@ func parseMQTTMetaData(md bindings.Metadata, log logger.Logger) (*metadata, erro
|
|||
|
||||
if val, ok := md.Properties[mqttCACert]; ok && val != "" {
|
||||
if !isValidPEM(val) {
|
||||
return &m, errors.New("invalid ca certificate")
|
||||
return m, errors.New("invalid ca certificate")
|
||||
}
|
||||
m.tlsCfg.caCert = val
|
||||
}
|
||||
if val, ok := md.Properties[mqttClientCert]; ok && val != "" {
|
||||
if !isValidPEM(val) {
|
||||
return &m, errors.New("invalid client certificate")
|
||||
return m, errors.New("invalid client certificate")
|
||||
}
|
||||
m.tlsCfg.clientCert = val
|
||||
}
|
||||
if val, ok := md.Properties[mqttClientKey]; ok && val != "" {
|
||||
if !isValidPEM(val) {
|
||||
return &m, errors.New("invalid client certificate key")
|
||||
return m, errors.New("invalid client certificate key")
|
||||
}
|
||||
m.tlsCfg.clientKey = val
|
||||
}
|
||||
|
@ -117,7 +117,7 @@ func parseMQTTMetaData(md bindings.Metadata, log logger.Logger) (*metadata, erro
|
|||
if val, ok := md.Properties[mqttBackOffMaxRetries]; ok && val != "" {
|
||||
backOffMaxRetriesInt, err := strconv.Atoi(val)
|
||||
if err != nil {
|
||||
return &m, fmt.Errorf("invalid backOffMaxRetries %s: %v", val, err)
|
||||
return m, fmt.Errorf("invalid backOffMaxRetries %s: %v", val, err)
|
||||
}
|
||||
m.backOffMaxRetries = backOffMaxRetriesInt
|
||||
}
|
||||
|
@ -128,7 +128,7 @@ func parseMQTTMetaData(md bindings.Metadata, log logger.Logger) (*metadata, erro
|
|||
log.Warn("Metadata property 'qos' has been deprecated and is ignored; qos is set to 1")
|
||||
}
|
||||
|
||||
return &m, nil
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// isValidPEM validates the provided input has PEM formatted block.
|
||||
|
|
|
@ -20,6 +20,7 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
|
@ -34,7 +35,8 @@ import (
|
|||
// MQTT allows sending and receiving data to/from an MQTT broker.
|
||||
type MQTT struct {
|
||||
producer mqtt.Client
|
||||
metadata *metadata
|
||||
producerLock sync.RWMutex
|
||||
metadata metadata
|
||||
logger logger.Logger
|
||||
isSubscribed atomic.Bool
|
||||
readHandler bindings.Handler
|
||||
|
@ -51,28 +53,18 @@ func NewMQTT(logger logger.Logger) bindings.InputOutputBinding {
|
|||
}
|
||||
|
||||
// Init does MQTT connection parsing.
|
||||
func (m *MQTT) Init(metadata bindings.Metadata) error {
|
||||
mqttMeta, err := parseMQTTMetaData(metadata, m.logger)
|
||||
func (m *MQTT) Init(metadata bindings.Metadata) (err error) {
|
||||
m.metadata, err = parseMQTTMetaData(metadata, m.logger)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.metadata = mqttMeta
|
||||
|
||||
m.ctx, m.cancel = context.WithCancel(context.Background())
|
||||
|
||||
// mqtt broker allows only one connection at a given time from a clientID.
|
||||
producerClientID := fmt.Sprintf("%s-producer", m.metadata.clientID)
|
||||
p, err := m.connect(producerClientID, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// TODO: Make the backoff configurable for constant or exponential
|
||||
b := backoff.NewConstantBackOff(5 * time.Second)
|
||||
m.backOff = backoff.WithContext(b, m.ctx)
|
||||
|
||||
m.producer = p
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -82,7 +74,42 @@ func (m *MQTT) Operations() []bindings.OperationKind {
|
|||
}
|
||||
}
|
||||
|
||||
func (m *MQTT) getProducer() (mqtt.Client, error) {
|
||||
// Get the producer from the cache
|
||||
m.producerLock.RLock()
|
||||
producer := m.producer
|
||||
m.producerLock.RUnlock()
|
||||
if producer != nil {
|
||||
return producer, nil
|
||||
}
|
||||
|
||||
// Must create a new producer
|
||||
m.producerLock.Lock()
|
||||
defer m.producerLock.Unlock()
|
||||
|
||||
// Check again in case another goroutine created it in the meanwhile
|
||||
producer = m.producer
|
||||
if producer != nil {
|
||||
return producer, nil
|
||||
}
|
||||
|
||||
// mqtt broker allows only one connection at a given time from a clientID.
|
||||
producerClientID := fmt.Sprintf("%s-producer", m.metadata.clientID)
|
||||
p, err := m.connect(producerClientID, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m.producer = p
|
||||
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func (m *MQTT) Invoke(parentCtx context.Context, req *bindings.InvokeRequest) (*bindings.InvokeResponse, error) {
|
||||
producer, err := m.getProducer()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create producer connection: %w", err)
|
||||
}
|
||||
|
||||
// MQTT client Publish() has an internal race condition in the default autoreconnect config.
|
||||
// To mitigate sporadic failures on the Dapr side, this implementation retries 3 times at
|
||||
// a fixed 200ms interval. This is not configurable to keep this as an implementation detail
|
||||
|
@ -99,7 +126,7 @@ func (m *MQTT) Invoke(parentCtx context.Context, req *bindings.InvokeRequest) (*
|
|||
topic = m.metadata.topic
|
||||
}
|
||||
return nil, retry.NotifyRecover(func() (err error) {
|
||||
token := m.producer.Publish(topic, m.metadata.qos, m.metadata.retain, req.Data)
|
||||
token := producer.Publish(topic, m.metadata.qos, m.metadata.retain, req.Data)
|
||||
ctx, cancel := context.WithTimeout(parentCtx, defaultWait)
|
||||
defer cancel()
|
||||
select {
|
||||
|
@ -123,37 +150,6 @@ func (m *MQTT) Invoke(parentCtx context.Context, req *bindings.InvokeRequest) (*
|
|||
})
|
||||
}
|
||||
|
||||
func (m *MQTT) handleMessage(ctx context.Context, handler bindings.Handler, mqttMsg mqtt.Message) error {
|
||||
msg := bindings.ReadResponse{
|
||||
Data: mqttMsg.Payload(),
|
||||
Metadata: map[string]string{
|
||||
mqttTopic: mqttMsg.Topic(),
|
||||
},
|
||||
}
|
||||
|
||||
// paho.mqtt.golang requires that handlers never block or it can deadlock on client.Disconnect.
|
||||
// To ensure that the Dapr runtime does not hang on teardown on of the component, run the app's
|
||||
// handling code in a goroutine so that this handler function is always cancellable on Close().
|
||||
ch := make(chan error)
|
||||
go func(m *bindings.ReadResponse) {
|
||||
defer close(ch)
|
||||
_, err := handler(ctx, m)
|
||||
ch <- err
|
||||
}(&msg)
|
||||
|
||||
select {
|
||||
case handlerErr := <-ch:
|
||||
if handlerErr != nil {
|
||||
return handlerErr
|
||||
}
|
||||
mqttMsg.Ack()
|
||||
return nil
|
||||
case <-m.ctx.Done():
|
||||
m.logger.Infof("Read context cancelled: %v", m.ctx.Err())
|
||||
return m.ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
func (m *MQTT) Read(ctx context.Context, handler bindings.Handler) error {
|
||||
// If the subscription is already active, wait 2s before retrying (in case we're still disconnecting), otherwise return an error
|
||||
if !m.isSubscribed.CompareAndSwap(false, true) {
|
||||
|
@ -313,7 +309,19 @@ func (m *MQTT) createSubscriberClientOptions(uri *url.URL, clientID string) *mqt
|
|||
err := retry.NotifyRecover(
|
||||
func() error {
|
||||
m.logger.Debugf("Processing MQTT message %s/%d", mqttMsg.Topic(), mqttMsg.MessageID())
|
||||
return m.handleMessage(ctx, m.readHandler, mqttMsg)
|
||||
_, err := m.readHandler(ctx, &bindings.ReadResponse{
|
||||
Data: mqttMsg.Payload(),
|
||||
Metadata: map[string]string{
|
||||
mqttTopic: mqttMsg.Topic(),
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Ack the message on success
|
||||
mqttMsg.Ack()
|
||||
return nil
|
||||
},
|
||||
bo,
|
||||
func(err error, d time.Duration) {
|
||||
|
@ -350,10 +358,16 @@ func (m *MQTT) createSubscriberClientOptions(uri *url.URL, clientID string) *mqt
|
|||
}
|
||||
|
||||
func (m *MQTT) Close() error {
|
||||
m.producerLock.Lock()
|
||||
defer m.producerLock.Unlock()
|
||||
|
||||
// Canceling the context also causes Read to stop receiving messages
|
||||
m.cancel()
|
||||
|
||||
m.producer.Disconnect(200)
|
||||
if m.producer != nil {
|
||||
m.producer.Disconnect(200)
|
||||
m.producer = nil
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue