diff --git a/cmd/notify-mailer/main.go b/cmd/notify-mailer/main.go index 5e61b8321..06434c0f3 100644 --- a/cmd/notify-mailer/main.go +++ b/cmd/notify-mailer/main.go @@ -13,6 +13,7 @@ import ( "sort" "strconv" "strings" + "sync" "text/template" "time" @@ -36,6 +37,7 @@ type mailer struct { recipients []recipient targetRange interval sleepInterval time.Duration + parallelSends uint } // interval defines a range of email addresses to send to in alphabetical order. @@ -154,56 +156,87 @@ func (m *mailer) run() error { m.log.Infof("Address %q was associated with the most recipients (%d)", mostRecipients, mostRecipientsLen) - conn, err := m.mailer.Connect() - if err != nil { - return err + type work struct { + index int + address string } - defer func() { _ = conn.Close() }() + var wg sync.WaitGroup + workChan := make(chan work, totalAddresses) startTime := m.clk.Now() sortedAddresses := sortAddresses(addressToRecipient) - var sent int - for i, address := range sortedAddresses { - if !m.targetRange.includes(address) { - m.log.Debugf("Address %q is outside of target range, skipping", address) - continue + if (m.targetRange.start != "" && m.targetRange.start > sortedAddresses[totalAddresses-1]) || + (m.targetRange.end != "" && m.targetRange.end < sortedAddresses[0]) { + return errors.New("Zero found addresses fall inside target range") + } + + go func(ch chan<- work) { + for i, address := range sortedAddresses { + ch <- work{i, address} + } + close(workChan) + }(workChan) + + if m.parallelSends < 1 { + m.parallelSends = 1 + } + + for senderNum := uint(0); senderNum < m.parallelSends; senderNum++ { + // For politeness' sake, don't open more than 1 new connection per + // second. + if senderNum > 0 { + m.clk.Sleep(time.Second) } - err := policy.ValidEmail(address) + conn, err := m.mailer.Connect() if err != nil { - m.log.Infof("Skipping %q due to policy violation: %s", address, err) - continue + return fmt.Errorf("connecting parallel sender %d: %w", senderNum, err) } - recipients := addressToRecipient[address] - m.logStatus(address, i+1, totalAddresses, startTime) + wg.Add(1) + go func(conn bmail.Conn, ch <-chan work) { + defer wg.Done() + for w := range ch { + if !m.targetRange.includes(w.address) { + m.log.Debugf("Address %q is outside of target range, skipping", w.address) + continue + } - messageBody, err := m.makeMessageBody(recipients) - if err != nil { - m.log.Errf("Skipping %q due to templating error: %s", address, err) - continue - } + err := policy.ValidEmail(w.address) + if err != nil { + m.log.Infof("Skipping %q due to policy violation: %s", w.address, err) + continue + } - err = conn.SendMail([]string{address}, m.subject, messageBody) - if err != nil { - var badAddrErr bmail.BadAddressSMTPError - if errors.As(err, &badAddrErr) { - m.log.Errf("address %q was rejected by server: %s", address, err) - continue + recipients := addressToRecipient[w.address] + m.logStatus(w.address, w.index+1, totalAddresses, startTime) + + messageBody, err := m.makeMessageBody(recipients) + if err != nil { + m.log.Errf("Skipping %q due to templating error: %s", w.address, err) + continue + } + + err = conn.SendMail([]string{w.address}, m.subject, messageBody) + if err != nil { + var badAddrErr bmail.BadAddressSMTPError + if errors.As(err, &badAddrErr) { + m.log.Errf("address %q was rejected by server: %s", w.address, err) + continue + } + m.log.AuditErrf("while sending mail (%d) of (%d) to address %q: %s", + w.index, len(sortedAddresses), w.address, err) + } + + m.clk.Sleep(m.sleepInterval) } - return fmt.Errorf("while sending mail (%d) of (%d) to address %q: %s", - i, len(sortedAddresses), address, err) - } - - sent++ - m.clk.Sleep(m.sleepInterval) + conn.Close() + }(conn, workChan) } + wg.Wait() - if sent == 0 { - return errors.New("0 messages sent, check recipients or configured interval") - } return nil } @@ -484,6 +517,7 @@ func main() { bodyFile := flag.String("body", "", "File containing the email body in Golang template format.") dryRun := flag.Bool("dryRun", true, "Whether to do a dry run.") sleep := flag.Duration("sleep", 500*time.Millisecond, "How long to sleep between emails.") + parallelSends := flag.Uint("parallelSends", 1, "How many parallel goroutines should process emails") start := flag.String("start", "", "Alphabetically lowest email address to include.") end := flag.String("end", "\xFF", "Alphabetically highest email address (exclusive).") reconnBase := flag.Duration("reconnectBase", 1*time.Second, "Base sleep duration between reconnect attempts") @@ -498,8 +532,7 @@ func main() { // Validate required args. flag.Parse() - if *from == "" || *subject == "" || *bodyFile == "" || *configFile == "" || - *recipientListFile == "" { + if *from == "" || *subject == "" || *bodyFile == "" || *configFile == "" || *recipientListFile == "" { flag.Usage() os.Exit(1) } @@ -571,6 +604,7 @@ func main() { end: *end, }, sleepInterval: *sleep, + parallelSends: *parallelSends, } err = m.run() diff --git a/cmd/notify-mailer/main_test.go b/cmd/notify-mailer/main_test.go index 660ca7d99..11039b6c3 100644 --- a/cmd/notify-mailer/main_test.go +++ b/cmd/notify-mailer/main_test.go @@ -278,12 +278,14 @@ func TestSleepInterval(t *testing.T) { dbMap := mockEmailResolver{} tmpl := template.Must(template.New("letter").Parse("an email body")) recipients := []recipient{{id: 1}, {id: 2}, {id: 3}} - // Set up a mock mailer that sleeps for `sleepLen` seconds + // Set up a mock mailer that sleeps for `sleepLen` seconds and only has one + // goroutine to process results m := &mailer{ log: blog.UseMock(), mailer: mc, emailTemplate: tmpl, sleepInterval: sleepLen * time.Second, + parallelSends: 1, targetRange: interval{start: "", end: "\xFF"}, clk: newFakeClock(t), recipients: recipients, @@ -431,6 +433,52 @@ func TestMailIntervals(t *testing.T) { }, mc.Messages[1]) } +func TestParallelism(t *testing.T) { + const testSubject = "Test Subject" + dbMap := mockEmailResolver{} + + tmpl := template.Must(template.New("letter").Parse("an email body")) + recipients := []recipient{{id: 1}, {id: 2}, {id: 3}, {id: 4}} + + mc := &mocks.Mailer{} + + // Create a mailer with 10 parallel workers. + m := &mailer{ + log: blog.UseMock(), + mailer: mc, + dbMap: dbMap, + subject: testSubject, + recipients: recipients, + emailTemplate: tmpl, + targetRange: interval{end: "\xFF"}, + sleepInterval: 0, + parallelSends: 10, + clk: newFakeClock(t), + } + + mc.Clear() + err := m.run() + test.AssertNotError(t, err, "run() produced an error") + + // The fake clock should have advanced 9 seconds, one for each parallel + // goroutine after the first doing its polite 1-second sleep at startup. + expectedEnd := newFakeClock(t) + expectedEnd.Add(9 * time.Second) + test.AssertEquals(t, m.clk.Now(), expectedEnd.Now()) + + // A message should have been sent to all four addresses. + test.AssertEquals(t, len(mc.Messages), 4) + expectedAddresses := []string{ + "example@letsencrypt.org", + "test-example-updated@letsencrypt.org", + "test-test-test@letsencrypt.org", + "example-example-example@letsencrypt.org", + } + for _, msg := range mc.Messages { + test.AssertSliceContains(t, expectedAddresses, msg.To) + } +} + func TestMessageContentStatic(t *testing.T) { // Create a mailer with fixed content const (