Fix bindings.mqtt deadlock on Disconnect()

Update bindings.mqtt component to avoid deadlocking on Close() when the
app-provided message handler is blocked.

- Run the app-provided message handler in a separate goroutine so that
  the wrapper handler function it passes to the mqtt library can be
  canceled.
- Cancel the read context to terminate any in flight message handling
  before calling Disconnect() in Close()
This commit is contained in:
Simon Leet 2021-10-26 21:42:15 +00:00
parent dfb9e90ede
commit 7fedcfc751
1 changed files with 31 additions and 9 deletions

View File

@ -196,6 +196,32 @@ func (m *MQTT) Invoke(req *bindings.InvokeRequest) (*bindings.InvokeResponse, er
return nil, nil
}
func (m *MQTT) handleMessage(handler func(*bindings.ReadResponse) ([]byte, error), mqttMsg mqtt.Message) error {
msg := bindings.ReadResponse{Data: mqttMsg.Payload()}
// paho.mqtt.golang requires that handlers never block or it can deadlock on client.Disconnect.
// To ensure that the Dapr runtime does not hang on teardown on of the component, run the app's
// handling code in a goroutine so that this handler function is always cancellable on Close().
ch := make(chan error)
go func(m *bindings.ReadResponse) {
defer close(ch)
_, err := handler(m)
ch <- err
}(&msg)
select {
case handlerErr := <-ch:
if handlerErr != nil {
return handlerErr
}
mqttMsg.Ack()
return nil
case <-m.ctx.Done():
m.logger.Infof("Read context cancelled: %v", m.ctx.Err())
return m.ctx.Err()
}
}
func (m *MQTT) Read(handler func(*bindings.ReadResponse) ([]byte, error)) error {
sigterm := make(chan os.Signal, 1)
signal.Notify(sigterm, os.Interrupt, syscall.SIGTERM)
@ -217,21 +243,14 @@ func (m *MQTT) Read(handler func(*bindings.ReadResponse) ([]byte, error)) error
m.logger.Debugf("mqtt subscribing to topic %s", m.metadata.topic)
token := m.consumer.Subscribe(m.metadata.topic, m.metadata.qos, func(client mqtt.Client, mqttMsg mqtt.Message) {
msg := bindings.ReadResponse{Data: mqttMsg.Payload()}
b := m.backOff
if m.metadata.backOffMaxRetries >= 0 {
b = backoff.WithMaxRetries(m.backOff, uint64(m.metadata.backOffMaxRetries))
}
if err := retry.NotifyRecover(func() error {
m.logger.Debugf("Processing MQTT message %s/%d", mqttMsg.Topic(), mqttMsg.MessageID())
if _, err := handler(&msg); err != nil {
return err
}
mqttMsg.Ack()
return nil
return m.handleMessage(handler, mqttMsg)
}, b, func(err error, d time.Duration) {
m.logger.Errorf("Error processing MQTT message: %s/%d. Retrying...", mqttMsg.Topic(), mqttMsg.MessageID())
}, func() {
@ -313,6 +332,9 @@ func (m *MQTT) createClientOptions(uri *url.URL, clientID string) *mqtt.ClientOp
}
func (m *MQTT) Close() error {
// Cancel any read callback handlers before Disconnect to prevent deadlocks.
m.cancel()
if m.consumer != nil {
m.consumer.Disconnect(1)
}