components-contrib/bindings/rabbitmq/rabbitmq.go

169 lines
3.3 KiB
Go

// ------------------------------------------------------------
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
// ------------------------------------------------------------
package rabbitmq
import (
	"encoding/json"
	"errors"
	"strconv"
	"time"

	"github.com/dapr/components-contrib/bindings"
	"github.com/dapr/dapr/pkg/logger"
	"github.com/streadway/amqp"
)
// rabbitMQQueueMessageTTLKey is the queue-declaration argument that carries the
// default per-queue message TTL (value in milliseconds).
const rabbitMQQueueMessageTTLKey = "x-message-ttl"
// RabbitMQ is a binding that sends data to, and receives data from, a single
// RabbitMQ queue. Construct it with NewRabbitMQ and call Init before use.
type RabbitMQ struct {
	connection *amqp.Connection // AMQP connection opened by Init
	channel    *amqp.Channel    // channel used to declare, publish, consume and ack
	metadata   rabbitMQMetadata // configuration parsed from the component metadata
	logger     logger.Logger
	queue      amqp.Queue // queue declared by Init; Read consumes from its Name
}
// rabbitMQMetadata is the rabbitmq binding configuration, decoded from the
// component's metadata properties via their JSON tags.
type rabbitMQMetadata struct {
	// QueueName is the name of the queue that is declared and published to.
	QueueName string `json:"queueName"`
	// Host is the broker address passed to amqp.Dial.
	Host string `json:"host"`
	// Durable is passed as the durable flag when declaring the queue.
	// The ",string" tag option means it is supplied as a JSON string ("true"/"false").
	Durable bool `json:"durable,string"`
	// DeleteWhenUnused is passed as the auto-delete flag when declaring the queue.
	// Also supplied as a JSON string.
	DeleteWhenUnused bool `json:"deleteWhenUnused,string"`
	// defaultQueueTTL is the optional default message TTL for the queue.
	// It is unexported, so JSON decoding ignores it. It is filled in separately from the TTL
	// metadata, and nil means no queue-level TTL is configured.
	defaultQueueTTL *time.Duration
}
// NewRabbitMQ creates a RabbitMQ binding that logs through the given logger.
// The returned binding holds no connection yet; call Init before Read or Write.
func NewRabbitMQ(logger logger.Logger) *RabbitMQ {
	binding := new(RabbitMQ)
	binding.logger = logger
	return binding
}
// Init parses the binding metadata, dials the broker at the configured host,
// opens a channel and declares the configured queue.
//
// If opening the channel or declaring the queue fails, the connection that was
// already dialed is closed before the error is returned, so a failed Init does
// not leak a live AMQP connection.
func (r *RabbitMQ) Init(metadata bindings.Metadata) error {
	if err := r.parseMetadata(metadata); err != nil {
		return err
	}

	conn, err := amqp.Dial(r.metadata.Host)
	if err != nil {
		return err
	}

	ch, err := conn.Channel()
	if err != nil {
		// Best-effort cleanup; the Channel error is the one worth reporting.
		_ = conn.Close()
		return err
	}

	// declareQueue operates on r.channel, so the handles must be stored first.
	r.connection = conn
	r.channel = ch

	q, err := r.declareQueue()
	if err != nil {
		// Best-effort cleanup; closing the connection also tears down ch.
		_ = conn.Close()
		return err
	}
	r.queue = q
	return nil
}
// Write publishes req.Data to the configured queue as a persistent,
// "text/plain" message. It goes through the default exchange (empty exchange
// name) with the queue name as the routing key.
//
// The queue may already carry a default TTL. A TTL found in req.Metadata (as
// recognized by bindings.TryGetTTL) overrides it for this one message and is
// sent as the per-message expiration in milliseconds.
func (r *RabbitMQ) Write(req *bindings.WriteRequest) error {
	ttl, hasTTL, err := bindings.TryGetTTL(req.Metadata)
	if err != nil {
		return err
	}

	msg := amqp.Publishing{
		DeliveryMode: amqp.Persistent,
		ContentType:  "text/plain",
		Body:         req.Data,
	}
	if hasTTL {
		// RabbitMQ expects the expiration as a decimal string of milliseconds.
		msg.Expiration = strconv.FormatInt(ttl.Milliseconds(), 10)
	}

	return r.channel.Publish("", r.metadata.QueueName, false, false, msg)
}
// parseMetadata decodes metadata.Properties into a rabbitMQMetadata. It
// round-trips the properties through JSON so the struct's json tags (including
// the ",string" bool fields) apply. It also records an optional default queue
// TTL. r.metadata is only replaced when every step succeeds.
func (r *RabbitMQ) parseMetadata(metadata bindings.Metadata) error {
	raw, err := json.Marshal(metadata.Properties)
	if err != nil {
		return err
	}

	parsed := rabbitMQMetadata{}
	if err = json.Unmarshal(raw, &parsed); err != nil {
		return err
	}

	ttl, found, err := bindings.TryGetTTL(metadata.Properties)
	if err != nil {
		return err
	}
	if found {
		parsed.defaultQueueTTL = &ttl
	}

	r.metadata = parsed
	return nil
}
// declareQueue declares the configured queue on r.channel and returns the
// queue as reported by the server. When a default TTL is configured, it is
// passed as the x-message-ttl queue argument, in milliseconds.
func (r *RabbitMQ) declareQueue() (amqp.Queue, error) {
	args := amqp.Table{}
	if ttl := r.metadata.defaultQueueTTL; ttl != nil {
		// NOTE(review): kept as int rather than int64 so the encoded argument
		// type stays the same. Redeclaring an existing queue with a differently
		// typed argument may be rejected by the broker; confirm before changing.
		args[rabbitMQQueueMessageTTLKey] = int(ttl.Milliseconds())
	}

	return r.channel.QueueDeclare(
		r.metadata.QueueName,
		r.metadata.Durable,
		r.metadata.DeleteWhenUnused,
		false, // exclusive
		false, // noWait
		args,
	)
}
// Read consumes messages from the queue declared in Init and passes each
// message body to handler. It blocks the caller while deliveries arrive.
//
// Acknowledgement is manual: a delivery is acked only after handler returns
// nil. When handler returns an error, the delivery is neither acked nor
// rejected, so it stays unacknowledged on this channel.
//
// Read returns when the amqp library closes the delivery channel, for example
// because the channel or connection was closed. It reports this as an error so
// the caller can tell that consumption has stopped.
func (r *RabbitMQ) Read(handler func(*bindings.ReadResponse) error) error {
	msgs, err := r.channel.Consume(
		r.queue.Name,
		"",    // consumer tag: empty lets the library generate a unique one
		false, // autoAck: deliveries are acked explicitly below
		false, // exclusive
		false, // noLocal
		false, // noWait
		nil,   // args
	)
	if err != nil {
		return err
	}

	// Deliveries are processed on the calling goroutine. The range ends when
	// the delivery channel is closed, so Read returns instead of hanging once
	// the connection goes away.
	for d := range msgs {
		if err := handler(&bindings.ReadResponse{Data: d.Body}); err != nil {
			// Deliberately not acked, so the message is not dropped.
			continue
		}
		if err := r.channel.Ack(d.DeliveryTag, false); err != nil {
			r.logger.Warnf("rabbitmq binding: error acking delivery %d: %s", d.DeliveryTag, err)
		}
	}

	return errors.New("rabbitmq binding: delivery channel closed")
}