Parallelize notify-mailer (#6268)

Use the same pattern as was recently implemented in
expiration-mailer to parallelize notify-mailer. This should
significantly increase throughput when sending emails
to all subscribers.
This commit is contained in:
Aaron Gable 2022-08-02 16:18:01 -07:00 committed by GitHub
parent 0eec51f0b7
commit f5525ccd15
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 119 additions and 37 deletions

View File

@ -13,6 +13,7 @@ import (
"sort"
"strconv"
"strings"
"sync"
"text/template"
"time"
@ -36,6 +37,7 @@ type mailer struct {
recipients []recipient
targetRange interval
sleepInterval time.Duration
parallelSends uint
}
// interval defines a range of email addresses to send to in alphabetical order.
@ -154,56 +156,87 @@ func (m *mailer) run() error {
m.log.Infof("Address %q was associated with the most recipients (%d)",
mostRecipients, mostRecipientsLen)
conn, err := m.mailer.Connect()
if err != nil {
return err
type work struct {
index int
address string
}
defer func() { _ = conn.Close() }()
var wg sync.WaitGroup
workChan := make(chan work, totalAddresses)
startTime := m.clk.Now()
sortedAddresses := sortAddresses(addressToRecipient)
var sent int
for i, address := range sortedAddresses {
if !m.targetRange.includes(address) {
m.log.Debugf("Address %q is outside of target range, skipping", address)
continue
if (m.targetRange.start != "" && m.targetRange.start > sortedAddresses[totalAddresses-1]) ||
(m.targetRange.end != "" && m.targetRange.end < sortedAddresses[0]) {
return errors.New("Zero found addresses fall inside target range")
}
go func(ch chan<- work) {
for i, address := range sortedAddresses {
ch <- work{i, address}
}
close(workChan)
}(workChan)
if m.parallelSends < 1 {
m.parallelSends = 1
}
for senderNum := uint(0); senderNum < m.parallelSends; senderNum++ {
// For politeness' sake, don't open more than 1 new connection per
// second.
if senderNum > 0 {
m.clk.Sleep(time.Second)
}
err := policy.ValidEmail(address)
conn, err := m.mailer.Connect()
if err != nil {
m.log.Infof("Skipping %q due to policy violation: %s", address, err)
continue
return fmt.Errorf("connecting parallel sender %d: %w", senderNum, err)
}
recipients := addressToRecipient[address]
m.logStatus(address, i+1, totalAddresses, startTime)
wg.Add(1)
go func(conn bmail.Conn, ch <-chan work) {
defer wg.Done()
for w := range ch {
if !m.targetRange.includes(w.address) {
m.log.Debugf("Address %q is outside of target range, skipping", w.address)
continue
}
messageBody, err := m.makeMessageBody(recipients)
if err != nil {
m.log.Errf("Skipping %q due to templating error: %s", address, err)
continue
}
err := policy.ValidEmail(w.address)
if err != nil {
m.log.Infof("Skipping %q due to policy violation: %s", w.address, err)
continue
}
err = conn.SendMail([]string{address}, m.subject, messageBody)
if err != nil {
var badAddrErr bmail.BadAddressSMTPError
if errors.As(err, &badAddrErr) {
m.log.Errf("address %q was rejected by server: %s", address, err)
continue
recipients := addressToRecipient[w.address]
m.logStatus(w.address, w.index+1, totalAddresses, startTime)
messageBody, err := m.makeMessageBody(recipients)
if err != nil {
m.log.Errf("Skipping %q due to templating error: %s", w.address, err)
continue
}
err = conn.SendMail([]string{w.address}, m.subject, messageBody)
if err != nil {
var badAddrErr bmail.BadAddressSMTPError
if errors.As(err, &badAddrErr) {
m.log.Errf("address %q was rejected by server: %s", w.address, err)
continue
}
m.log.AuditErrf("while sending mail (%d) of (%d) to address %q: %s",
w.index, len(sortedAddresses), w.address, err)
}
m.clk.Sleep(m.sleepInterval)
}
return fmt.Errorf("while sending mail (%d) of (%d) to address %q: %s",
i, len(sortedAddresses), address, err)
}
sent++
m.clk.Sleep(m.sleepInterval)
conn.Close()
}(conn, workChan)
}
wg.Wait()
if sent == 0 {
return errors.New("0 messages sent, check recipients or configured interval")
}
return nil
}
@ -484,6 +517,7 @@ func main() {
bodyFile := flag.String("body", "", "File containing the email body in Golang template format.")
dryRun := flag.Bool("dryRun", true, "Whether to do a dry run.")
sleep := flag.Duration("sleep", 500*time.Millisecond, "How long to sleep between emails.")
parallelSends := flag.Uint("parallelSends", 1, "How many parallel goroutines should process emails")
start := flag.String("start", "", "Alphabetically lowest email address to include.")
end := flag.String("end", "\xFF", "Alphabetically highest email address (exclusive).")
reconnBase := flag.Duration("reconnectBase", 1*time.Second, "Base sleep duration between reconnect attempts")
@ -498,8 +532,7 @@ func main() {
// Validate required args.
flag.Parse()
if *from == "" || *subject == "" || *bodyFile == "" || *configFile == "" ||
*recipientListFile == "" {
if *from == "" || *subject == "" || *bodyFile == "" || *configFile == "" || *recipientListFile == "" {
flag.Usage()
os.Exit(1)
}
@ -571,6 +604,7 @@ func main() {
end: *end,
},
sleepInterval: *sleep,
parallelSends: *parallelSends,
}
err = m.run()

View File

@ -278,12 +278,14 @@ func TestSleepInterval(t *testing.T) {
dbMap := mockEmailResolver{}
tmpl := template.Must(template.New("letter").Parse("an email body"))
recipients := []recipient{{id: 1}, {id: 2}, {id: 3}}
// Set up a mock mailer that sleeps for `sleepLen` seconds
// Set up a mock mailer that sleeps for `sleepLen` seconds and only has one
// goroutine to process results
m := &mailer{
log: blog.UseMock(),
mailer: mc,
emailTemplate: tmpl,
sleepInterval: sleepLen * time.Second,
parallelSends: 1,
targetRange: interval{start: "", end: "\xFF"},
clk: newFakeClock(t),
recipients: recipients,
@ -431,6 +433,52 @@ func TestMailIntervals(t *testing.T) {
}, mc.Messages[1])
}
// TestParallelism runs the mailer with more parallel senders than there are
// recipients and checks that the polite connection-startup delay is applied,
// that the expected number of messages is sent, and that every expected
// address receives exactly one message.
func TestParallelism(t *testing.T) {
	const testSubject = "Test Subject"
	dbMap := mockEmailResolver{}

	tmpl := template.Must(template.New("letter").Parse("an email body"))
	recipients := []recipient{{id: 1}, {id: 2}, {id: 3}, {id: 4}}

	mc := &mocks.Mailer{}

	// Create a mailer with 10 parallel workers.
	m := &mailer{
		log:           blog.UseMock(),
		mailer:        mc,
		dbMap:         dbMap,
		subject:       testSubject,
		recipients:    recipients,
		emailTemplate: tmpl,
		targetRange:   interval{end: "\xFF"},
		sleepInterval: 0,
		parallelSends: 10,
		clk:           newFakeClock(t),
	}

	mc.Clear()
	err := m.run()
	test.AssertNotError(t, err, "run() produced an error")

	// The fake clock should have advanced 9 seconds, one for each parallel
	// goroutine after the first doing its polite 1-second sleep at startup.
	expectedEnd := newFakeClock(t)
	expectedEnd.Add(9 * time.Second)
	test.AssertEquals(t, m.clk.Now(), expectedEnd.Now())

	// A message should have been sent to all four addresses.
	test.AssertEquals(t, len(mc.Messages), 4)
	expectedAddresses := []string{
		"example@letsencrypt.org",
		"test-example-updated@letsencrypt.org",
		"test-test-test@letsencrypt.org",
		"example-example-example@letsencrypt.org",
	}

	// Checking membership alone is not enough: with parallel workers the
	// failure mode to guard against is one address being mailed twice while
	// another is skipped, which would still yield four "expected" messages.
	// Track which addresses have been seen so duplicates are reported.
	seen := make(map[string]bool, len(expectedAddresses))
	for _, msg := range mc.Messages {
		test.AssertSliceContains(t, expectedAddresses, msg.To)
		if seen[msg.To] {
			t.Errorf("address %q received more than one message", msg.To)
		}
		seen[msg.To] = true
	}
}
func TestMessageContentStatic(t *testing.T) {
// Create a mailer with fixed content
const (