Fix non-gRPC process cleanup and exit (#6808)

Although #6771 significantly cleaned up how gRPC services stop and clean
up, it didn't make any changes to our HTTP servers or our non-server
(e.g. crl-updater, log-validator) processes. This change finishes the
work.

Add a new helper method cmd.WaitForSignal, which simply blocks until one
of the three signals we care about is received. This easily replaces all
calls to cmd.CatchSignals which passed `nil` as the callback argument,
with the added advantage that it doesn't call os.Exit() and therefore
allows deferred cleanup functions to execute. This new function is
intended to be the last line of main(), allowing the whole process to
exit once it returns.

Reimplement cmd.CatchSignals as a thin wrapper around cmd.WaitForSignal,
but with the added callback functionality. Also remove the os.Exit()
call from CatchSignals, so that the main goroutine is allowed to finish
whatever it's doing, call deferred functions, and exit naturally.

Update all of our non-gRPC binaries to use one of these two functions.
The vast majority use WaitForSignal, as they run their main processing
loop in a background goroutine. A few (particularly those that can run
either in run-once or in daemonized mode) still use CatchSignals, since
their primary processing happens directly on the main goroutine.

The changes to //test/load-generator are the most invasive, simply
because that binary needed to have a context plumbed into it for proper
cancellation, but it already had a custom struct type named "context"
which needed to be renamed to avoid shadowing.

Fixes https://github.com/letsencrypt/boulder/issues/6794
This commit is contained in:
Aaron Gable 2023-04-14 13:22:56 -07:00 committed by GitHub
parent 98fa0f07b4
commit bd1d27b8e8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 210 additions and 171 deletions

View File

@ -535,9 +535,6 @@ func main() {
}
}()
// The gosec linter complains that ReadHeaderTimeout is not set. That's fine,
// because that field inherits its value from ReadTimeout.
////nolint:gosec
tlsSrv := http.Server{
ReadTimeout: 30 * time.Second,
WriteTimeout: 120 * time.Second,
@ -555,20 +552,18 @@ func main() {
}()
}
done := make(chan bool)
go cmd.CatchSignals(logger, func() {
// When main is ready to exit (because it has received a shutdown signal),
// gracefully shutdown the servers. Calling these shutdown functions causes
// ListenAndServe() and ListenAndServeTLS() to immediately return, then waits
// for any lingering connection-handling goroutines to finish their work.
defer func() {
ctx, cancel := context.WithTimeout(context.Background(), c.WFE.ShutdownStopTimeout.Duration)
defer cancel()
_ = srv.Shutdown(ctx)
_ = tlsSrv.Shutdown(ctx)
done <- true
})
}()
// https://godoc.org/net/http#Server.Shutdown:
// When Shutdown is called, Serve, ListenAndServe, and ListenAndServeTLS
// immediately return ErrServerClosed. Make sure the program doesn't exit and
// waits instead for Shutdown to return.
<-done
cmd.WaitForSignal()
}
func init() {

View File

@ -2,6 +2,7 @@ package notmain
import (
"context"
"errors"
"flag"
"os"
"time"
@ -165,15 +166,17 @@ func main() {
cmd.FailOnError(err, "Failed to create crl-updater")
ctx, cancel := context.WithCancel(context.Background())
go cmd.CatchSignals(logger, cancel)
go cmd.CatchSignals(cancel)
if *runOnce {
err = u.Tick(ctx, clk.Now())
cmd.FailOnError(err, "")
if err != nil && !errors.Is(err, context.Canceled) {
cmd.FailOnError(err, "")
}
} else {
err = u.Run(ctx)
if err != nil {
logger.Err(err.Error())
if err != nil && !errors.Is(err, context.Canceled) {
cmd.FailOnError(err, "")
}
}
}

View File

@ -803,6 +803,11 @@ func main() {
defer logger.AuditPanic()
logger.Info(cmd.VersionString())
if *daemon && c.Mailer.Frequency.Duration == 0 {
fmt.Fprintln(os.Stderr, "mailer.frequency is not set in the JSON config")
os.Exit(1)
}
if *certLimit > 0 {
c.Mailer.CertLimit = *certLimit
}
@ -913,31 +918,26 @@ func main() {
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go cmd.CatchSignals(logger, func() {
cancel()
select {} // wait for the `findExpiringCertificates` calls below to exit
})
go cmd.CatchSignals(cancel)
if *daemon {
if c.Mailer.Frequency.Duration == 0 {
fmt.Fprintln(os.Stderr, "mailer.Frequency is not set in the JSON config")
os.Exit(1)
}
t := time.NewTicker(c.Mailer.Frequency.Duration)
for {
select {
case <-t.C:
err = m.findExpiringCertificates(ctx)
cmd.FailOnError(err, "expiration-mailer has failed")
if err != nil && !errors.Is(err, context.Canceled) {
cmd.FailOnError(err, "expiration-mailer has failed")
}
case <-ctx.Done():
os.Exit(0)
break
}
}
} else {
err = m.findExpiringCertificates(ctx)
cmd.FailOnError(err, "expiration-mailer has failed")
if err != nil && !errors.Is(err, context.Canceled) {
cmd.FailOnError(err, "expiration-mailer has failed")
}
}
}

View File

@ -199,7 +199,7 @@ func main() {
tailers = append(tailers, t)
}
cmd.CatchSignals(logger, func() {
defer func() {
for _, t := range tailers {
// The tail module seems to have a race condition that will generate
// errors like this on shutdown:
@ -211,7 +211,9 @@ func main() {
_ = t.Stop()
t.Cleanup()
}
})
}()
cmd.WaitForSignal()
}
func init() {

View File

@ -227,25 +227,24 @@ as generated by Boulder's ceremony command.
Handler: m,
}
done := make(chan bool)
go cmd.CatchSignals(logger, func() {
ctx, cancel := context.WithTimeout(context.Background(),
c.OCSPResponder.ShutdownStopTimeout.Duration)
defer cancel()
_ = srv.Shutdown(ctx)
done <- true
})
err = srv.ListenAndServe()
if err != nil && err != http.ErrServerClosed {
cmd.FailOnError(err, "Running HTTP server")
}
// https://godoc.org/net/http#Server.Shutdown:
// When Shutdown is called, Serve, ListenAndServe, and ListenAndServeTLS
// immediately return ErrServerClosed. Make sure the program doesn't exit and
// waits instead for Shutdown to return.
<-done
// When main is ready to exit (because it has received a shutdown signal),
// gracefully shutdown the servers. Calling these shutdown functions causes
// ListenAndServe() to immediately return, cleaning up the server goroutines
// as well, then waits for any lingering connection-handing goroutines to
// finish and clean themselves up.
defer func() {
ctx, cancel := context.WithTimeout(context.Background(),
c.OCSPResponder.ShutdownStopTimeout.Duration)
defer cancel()
srv.Shutdown(ctx)
}()
cmd.WaitForSignal()
}
// ocspMux partially implements the interface defined for http.ServeMux but doesn't implement

View File

@ -418,18 +418,26 @@ func VersionString() string {
return fmt.Sprintf("Versions: %s=(%s %s) Golang=(%s) BuildHost=(%s)", command(), core.GetBuildID(), core.GetBuildTime(), runtime.Version(), core.GetBuildHost())
}
// CatchSignals catches SIGTERM, SIGINT, SIGHUP and executes a callback
// method before exiting
func CatchSignals(logger blog.Logger, callback func()) {
// CatchSignals blocks until a SIGTERM, SIGINT, or SIGHUP is received, then
// executes the given callback. The callback should not block, it should simply
// signal other goroutines (particularly the main goroutine) to clean themselves
// up and exit. This function is intended to be called in its own goroutine,
// while the main goroutine waits for an indication that the other goroutines
// have exited cleanly.
func CatchSignals(callback func()) {
WaitForSignal()
callback()
}
// WaitForSignal blocks until a SIGTERM, SIGINT, or SIGHUP is received. It then
// returns, allowing execution to resume, generally allowing a main() function
// to return and trigger and deferred cleanup functions. This function is
// intended to be called directly from the main goroutine, while a gRPC or HTTP
// server runs in a background goroutine.
func WaitForSignal() {
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGTERM)
signal.Notify(sigChan, syscall.SIGINT)
signal.Notify(sigChan, syscall.SIGHUP)
<-sigChan
if callback != nil {
callback()
}
os.Exit(0)
}

View File

@ -5,10 +5,7 @@ import (
"errors"
"fmt"
"net"
"os"
"os/signal"
"strings"
"syscall"
grpc_prometheus "github.com/grpc-ecosystem/go-grpc-prometheus"
"github.com/jmhodges/clock"
@ -175,15 +172,10 @@ func (sb *serverBuilder) Build(tlsConfig *tls.Config, statsRegistry prometheus.R
// Start a goroutine which listens for a termination signal, and then
// gracefully stops the gRPC server. This in turn causes the start() function
// to exit, allowing its caller (generally a main() function) to exit.
go func() {
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGTERM)
signal.Notify(sigChan, syscall.SIGINT)
signal.Notify(sigChan, syscall.SIGHUP)
<-sigChan
go cmd.CatchSignals(func() {
healthSrv.Shutdown()
server.GracefulStop()
}()
})
return start, nil
}

View File

@ -1,13 +1,14 @@
package main
import (
"context"
"encoding/json"
"flag"
"fmt"
"io"
"log"
"net/http"
"sync"
"time"
"github.com/letsencrypt/boulder/akamai"
"github.com/letsencrypt/boulder/cmd"
@ -92,9 +93,23 @@ func main() {
w.Write(resp)
})
// The gosec linter complains that timeouts cannot be set here. That's fine,
// because this is test-only code.
////nolint:gosec
go log.Fatal(http.ListenAndServe(*listenAddr, nil))
cmd.CatchSignals(nil, nil)
s := http.Server{
ReadTimeout: 30 * time.Second,
Addr: *listenAddr,
}
go func() {
err := s.ListenAndServe()
if err != nil && err != http.ErrServerClosed {
cmd.FailOnError(err, "Running TLS server")
}
}()
defer func() {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
s.Shutdown(ctx)
}()
cmd.WaitForSignal()
}

View File

@ -257,5 +257,5 @@ func main() {
for _, p := range c.Personalities {
go runPersonality(p)
}
cmd.CatchSignals(nil, nil)
cmd.WaitForSignal()
}

View File

@ -31,7 +31,7 @@ import (
var (
// stringToOperation maps a configured plan action to a function that can
// operate on a state/context.
stringToOperation = map[string]func(*State, *context) error{
stringToOperation = map[string]func(*State, *acmeCache) error{
"newAccount": newAccount,
"getAccount": getAccount,
"newOrder": newOrder,
@ -60,8 +60,8 @@ type OrderJSON struct {
}
// getAccount takes a randomly selected v2 account from `state.accts` and puts it
// into `ctx.acct`. The context `nonceSource` is also populated as convenience.
func getAccount(s *State, ctx *context) error {
// into `c.acct`. The context `nonceSource` is also populated as convenience.
func getAccount(s *State, c *acmeCache) error {
s.rMu.RLock()
defer s.rMu.RUnlock()
@ -71,8 +71,8 @@ func getAccount(s *State, ctx *context) error {
}
// Select a random account from the state and put it into the context
ctx.acct = s.accts[mrand.Intn(len(s.accts))]
ctx.ns = &nonceSource{s: s}
c.acct = s.accts[mrand.Intn(len(s.accts))]
c.ns = &nonceSource{s: s}
return nil
}
@ -81,11 +81,11 @@ func getAccount(s *State, ctx *context) error {
// then `newAccount` puts an existing account from the state into the context,
// otherwise it creates a new account and puts it into both the state and the
// context.
func newAccount(s *State, ctx *context) error {
func newAccount(s *State, c *acmeCache) error {
// Check the max regs and if exceeded, just return an existing account instead
// of creating a new one.
if s.maxRegs != 0 && s.numAccts() >= s.maxRegs {
return getAccount(s, ctx)
return getAccount(s, c)
}
// Create a random signing key
@ -93,10 +93,10 @@ func newAccount(s *State, ctx *context) error {
if err != nil {
return err
}
ctx.acct = &account{
c.acct = &account{
key: signKey,
}
ctx.ns = &nonceSource{s: s}
c.ns = &nonceSource{s: s}
// Prepare an account registration message body
reqBody := struct {
@ -117,7 +117,7 @@ func newAccount(s *State, ctx *context) error {
// Sign the new account registration body using a JWS with an embedded JWK
// because we do not have a key ID from the server yet.
newAccountURL := s.directory.EndpointURL(acme.NewAccountEndpoint)
jws, err := ctx.signEmbeddedV2Request(reqBodyStr, newAccountURL)
jws, err := c.signEmbeddedV2Request(reqBodyStr, newAccountURL)
if err != nil {
return err
}
@ -126,7 +126,7 @@ func newAccount(s *State, ctx *context) error {
resp, err := s.post(
newAccountURL,
bodyBuf,
ctx.ns,
c.ns,
string(acme.NewAccountEndpoint),
http.StatusCreated)
if err != nil {
@ -140,10 +140,10 @@ func newAccount(s *State, ctx *context) error {
if locHeader == "" {
return fmt.Errorf("%s, bad response - no Location header with account ID", newAccountURL)
}
ctx.acct.id = locHeader
c.acct.id = locHeader
// Add the account to the state
s.addAccount(ctx.acct)
s.addAccount(c.acct)
return nil
}
@ -160,7 +160,7 @@ func randDomain(base string) string {
// newOrder creates a new pending order object for a random set of domains using
// the context's account.
func newOrder(s *State, ctx *context) error {
func newOrder(s *State, c *acmeCache) error {
// Pick a random number of names within the constraints of the maxNamesPerCert
// parameter
orderSize := 1 + mrand.Intn(s.maxNamesPerCert-1)
@ -187,7 +187,7 @@ func newOrder(s *State, ctx *context) error {
// Sign the new order request with the context account's key/key ID
newOrderURL := s.directory.EndpointURL(acme.NewOrderEndpoint)
jws, err := ctx.signKeyIDV2Request(initOrderStr, newOrderURL)
jws, err := c.signKeyIDV2Request(initOrderStr, newOrderURL)
if err != nil {
return err
}
@ -196,7 +196,7 @@ func newOrder(s *State, ctx *context) error {
resp, err := s.post(
newOrderURL,
bodyBuf,
ctx.ns,
c.ns,
string(acme.NewOrderEndpoint),
http.StatusCreated)
if err != nil {
@ -223,24 +223,24 @@ func newOrder(s *State, ctx *context) error {
orderJSON.URL = orderURL
// Store the pending order in the context
ctx.pendingOrders = append(ctx.pendingOrders, &orderJSON)
c.pendingOrders = append(c.pendingOrders, &orderJSON)
return nil
}
// popPendingOrder *removes* a random pendingOrder from the context, returning
// it.
func popPendingOrder(ctx *context) *OrderJSON {
orderIndex := mrand.Intn(len(ctx.pendingOrders))
order := ctx.pendingOrders[orderIndex]
ctx.pendingOrders = append(ctx.pendingOrders[:orderIndex], ctx.pendingOrders[orderIndex+1:]...)
func popPendingOrder(c *acmeCache) *OrderJSON {
orderIndex := mrand.Intn(len(c.pendingOrders))
order := c.pendingOrders[orderIndex]
c.pendingOrders = append(c.pendingOrders[:orderIndex], c.pendingOrders[orderIndex+1:]...)
return order
}
// getAuthorization fetches an authorization by GET-ing the provided URL. It
// records the latency and result of the GET operation in the state.
func getAuthorization(s *State, ctx *context, url string) (*core.Authorization, error) {
func getAuthorization(s *State, c *acmeCache, url string) (*core.Authorization, error) {
latencyTag := "/acme/authz/{ID}"
resp, err := postAsGet(s, ctx, url, latencyTag)
resp, err := postAsGet(s, c, url, latencyTag)
// If there was an error, note the state and return
if err != nil {
return nil, fmt.Errorf("%s bad response: %s", url, err)
@ -269,7 +269,7 @@ func getAuthorization(s *State, ctx *context, url string) (*core.Authorization,
// HTTP-01 challenge using the context's account and the state's challenge
// server. Aftering POSTing the authorization's HTTP-01 challenge the
// authorization will be polled waiting for a state change.
func completeAuthorization(authz *core.Authorization, s *State, ctx *context) error {
func completeAuthorization(authz *core.Authorization, s *State, c *acmeCache) error {
// Skip if the authz isn't pending
if authz.Status != core.StatusPending {
return nil
@ -283,7 +283,7 @@ func completeAuthorization(authz *core.Authorization, s *State, ctx *context) er
}
// Compute the key authorization from the context account's key
jwk := &jose.JSONWebKey{Key: &ctx.acct.key.PublicKey}
jwk := &jose.JSONWebKey{Key: &c.acct.key.PublicKey}
thumbprint, err := jwk.Thumbprint(crypto.SHA256)
if err != nil {
return err
@ -311,7 +311,7 @@ func completeAuthorization(authz *core.Authorization, s *State, ctx *context) er
}
// Prepare the Challenge POST body
jws, err := ctx.signKeyIDV2Request([]byte(`{}`), chalToSolve.URL)
jws, err := c.signKeyIDV2Request([]byte(`{}`), chalToSolve.URL)
if err != nil {
return err
}
@ -320,7 +320,7 @@ func completeAuthorization(authz *core.Authorization, s *State, ctx *context) er
resp, err := s.post(
chalToSolve.URL,
requestPayload,
ctx.ns,
c.ns,
"/acme/challenge/{ID}", // We want all challenge POST latencies to be grouped
http.StatusOK,
)
@ -337,7 +337,7 @@ func completeAuthorization(authz *core.Authorization, s *State, ctx *context) er
// Poll the authorization waiting for the challenge response to be recorded in
// a change of state. The polling may sleep and retry a few times if required
err = pollAuthorization(authz, s, ctx)
err = pollAuthorization(authz, s, c)
if err != nil {
return err
}
@ -351,11 +351,11 @@ func completeAuthorization(authz *core.Authorization, s *State, ctx *context) er
// be valid. If the status is invalid, or if three GETs do not produce the
// correct authorization state an error is returned. If no error is returned
// then the authorization is valid and ready.
func pollAuthorization(authz *core.Authorization, s *State, ctx *context) error {
func pollAuthorization(authz *core.Authorization, s *State, c *acmeCache) error {
authzURL := authz.ID
for i := 0; i < 3; i++ {
// Fetch the authz by its URL
authz, err := getAuthorization(s, ctx, authzURL)
authz, err := getAuthorization(s, c, authzURL)
if err != nil {
return nil
}
@ -377,25 +377,25 @@ func pollAuthorization(authz *core.Authorization, s *State, ctx *context) error
// authorization's HTTP-01 challenge using the context's account, and finally
// placing the now-ready-to-be-finalized order into the context's list of
// fulfilled orders.
func fulfillOrder(s *State, ctx *context) error {
func fulfillOrder(s *State, c *acmeCache) error {
// There must be at least one pending order in the context to fulfill
if len(ctx.pendingOrders) == 0 {
if len(c.pendingOrders) == 0 {
return errors.New("no pending orders to fulfill")
}
// Get an order to fulfill from the context
order := popPendingOrder(ctx)
order := popPendingOrder(c)
// Each of its authorizations need to be processed
for _, url := range order.Authorizations {
// Fetch the authz by its URL
authz, err := getAuthorization(s, ctx, url)
authz, err := getAuthorization(s, c, url)
if err != nil {
return err
}
// Complete the authorization by solving a challenge
err = completeAuthorization(authz, s, ctx)
err = completeAuthorization(authz, s, c)
if err != nil {
return err
}
@ -403,16 +403,16 @@ func fulfillOrder(s *State, ctx *context) error {
// Once all of the authorizations have been fulfilled the order is fulfilled
// and ready for future finalization.
ctx.fulfilledOrders = append(ctx.fulfilledOrders, order.URL)
c.fulfilledOrders = append(c.fulfilledOrders, order.URL)
return nil
}
// getOrder GETs an order by URL, returning an OrderJSON object. It tracks the
// latency of the GET operation in the provided state.
func getOrder(s *State, ctx *context, url string) (*OrderJSON, error) {
func getOrder(s *State, c *acmeCache, url string) (*OrderJSON, error) {
latencyTag := "/acme/order/{ID}"
// POST-as-GET the order URL
resp, err := postAsGet(s, ctx, url, latencyTag)
resp, err := postAsGet(s, c, url, latencyTag)
// If there was an error, track that result
if err != nil {
return nil, fmt.Errorf("%s bad response: %s", url, err)
@ -440,10 +440,10 @@ func getOrder(s *State, ctx *context, url string) (*OrderJSON, error) {
// valid such that a certificate URL for the order is known. Three attempts are
// made to check the order status, sleeping 3s between each. If these attempts
// expire without the status becoming valid an error is returned.
func pollOrderForCert(order *OrderJSON, s *State, ctx *context) (*OrderJSON, error) {
func pollOrderForCert(order *OrderJSON, s *State, c *acmeCache) (*OrderJSON, error) {
for i := 0; i < 3; i++ {
// Fetch the order by its URL
order, err := getOrder(s, ctx, order.URL)
order, err := getOrder(s, c, order.URL)
if err != nil {
return nil, err
}
@ -463,10 +463,10 @@ func pollOrderForCert(order *OrderJSON, s *State, ctx *context) (*OrderJSON, err
// popFulfilledOrder **removes** a fulfilled order from the context, returning
// it. Fulfilled orders have all of their authorizations satisfied.
func popFulfilledOrder(ctx *context) string {
orderIndex := mrand.Intn(len(ctx.fulfilledOrders))
order := ctx.fulfilledOrders[orderIndex]
ctx.fulfilledOrders = append(ctx.fulfilledOrders[:orderIndex], ctx.fulfilledOrders[orderIndex+1:]...)
func popFulfilledOrder(c *acmeCache) string {
orderIndex := mrand.Intn(len(c.fulfilledOrders))
order := c.fulfilledOrders[orderIndex]
c.fulfilledOrders = append(c.fulfilledOrders[:orderIndex], c.fulfilledOrders[orderIndex+1:]...)
return order
}
@ -475,15 +475,15 @@ func popFulfilledOrder(ctx *context) string {
// `certKey`. The order is then polled for the status to change to valid so that
// the certificate URL can be added to the context. The context's `certs` list
// is updated with the URL for the order's certificate.
func finalizeOrder(s *State, ctx *context) error {
func finalizeOrder(s *State, c *acmeCache) error {
// There must be at least one fulfilled order in the context
if len(ctx.fulfilledOrders) < 1 {
if len(c.fulfilledOrders) < 1 {
return errors.New("No fulfilled orders in the context ready to be finalized")
}
// Pop a fulfilled order to process, and then GET its contents
orderID := popFulfilledOrder(ctx)
order, err := getOrder(s, ctx, orderID)
orderID := popFulfilledOrder(c)
order, err := getOrder(s, c, orderID)
if err != nil {
return err
}
@ -519,7 +519,7 @@ func finalizeOrder(s *State, ctx *context) error {
)
// Sign the request body with the context's account key/keyID
jws, err := ctx.signKeyIDV2Request([]byte(request), finalizeURL)
jws, err := c.signKeyIDV2Request([]byte(request), finalizeURL)
if err != nil {
return err
}
@ -528,7 +528,7 @@ func finalizeOrder(s *State, ctx *context) error {
resp, err := s.post(
finalizeURL,
requestPayload,
ctx.ns,
c.ns,
"/acme/order/finalize", // We want all order finalizations to be grouped.
http.StatusOK,
)
@ -544,7 +544,7 @@ func finalizeOrder(s *State, ctx *context) error {
}
// Poll the order waiting for the certificate to be ready
completedOrder, err := pollOrderForCert(order, s, ctx)
completedOrder, err := pollOrderForCert(order, s, c)
if err != nil {
return err
}
@ -556,8 +556,8 @@ func finalizeOrder(s *State, ctx *context) error {
}
// Append the certificate URL into the context's list of certificates
ctx.certs = append(ctx.certs, certURL)
ctx.finalizedOrders = append(ctx.finalizedOrders, order.URL)
c.certs = append(c.certs, certURL)
c.finalizedOrders = append(c.finalizedOrders, order.URL)
return nil
}
@ -567,27 +567,27 @@ func finalizeOrder(s *State, ctx *context) error {
// responsible for closing the HTTP response body.
//
// See RFC 8555 Section 6.3 for more information on POST-as-GET requests.
func postAsGet(s *State, ctx *context, url string, latencyTag string) (*http.Response, error) {
func postAsGet(s *State, c *acmeCache, url string, latencyTag string) (*http.Response, error) {
// Create the POST-as-GET request JWS
jws, err := ctx.signKeyIDV2Request([]byte(""), url)
jws, err := c.signKeyIDV2Request([]byte(""), url)
if err != nil {
return nil, err
}
requestPayload := []byte(jws.FullSerialize())
return s.post(url, requestPayload, ctx.ns, latencyTag, http.StatusOK)
return s.post(url, requestPayload, c.ns, latencyTag, http.StatusOK)
}
func popCertificate(ctx *context) string {
certIndex := mrand.Intn(len(ctx.certs))
certURL := ctx.certs[certIndex]
ctx.certs = append(ctx.certs[:certIndex], ctx.certs[certIndex+1:]...)
func popCertificate(c *acmeCache) string {
certIndex := mrand.Intn(len(c.certs))
certURL := c.certs[certIndex]
c.certs = append(c.certs[:certIndex], c.certs[certIndex+1:]...)
return certURL
}
func getCert(s *State, ctx *context, url string) ([]byte, error) {
func getCert(s *State, c *acmeCache, url string) ([]byte, error) {
latencyTag := "/acme/cert/{serial}"
resp, err := postAsGet(s, ctx, url, latencyTag)
resp, err := postAsGet(s, c, url, latencyTag)
if err != nil {
return nil, fmt.Errorf("%s bad response: %s", url, err)
}
@ -599,8 +599,8 @@ func getCert(s *State, ctx *context, url string) ([]byte, error) {
// and sends a revocation request for the certificate to the ACME server.
// The revocation request is signed with the account key rather than the certificate
// key.
func revokeCertificate(s *State, ctx *context) error {
if len(ctx.certs) < 1 {
func revokeCertificate(s *State, c *acmeCache) error {
if len(c.certs) < 1 {
return errors.New("No certificates in the context that can be revoked")
}
@ -608,8 +608,8 @@ func revokeCertificate(s *State, ctx *context) error {
return nil
}
certURL := popCertificate(ctx)
certPEM, err := getCert(s, ctx, certURL)
certURL := popCertificate(c)
certPEM, err := getCert(s, c, certURL)
if err != nil {
return err
}
@ -630,7 +630,7 @@ func revokeCertificate(s *State, ctx *context) error {
revokeURL := s.directory.EndpointURL(acme.RevokeCertEndpoint)
// TODO(roland): randomly use the certificate key to sign the request instead of
// the account key
jws, err := ctx.signKeyIDV2Request(revokeJSON, revokeURL)
jws, err := c.signKeyIDV2Request(revokeJSON, revokeURL)
if err != nil {
return err
}
@ -639,7 +639,7 @@ func revokeCertificate(s *State, ctx *context) error {
resp, err := s.post(
revokeURL,
requestPayload,
ctx.ns,
c.ns,
"/acme/revoke-cert",
http.StatusOK,
)

View File

@ -1,6 +1,7 @@
package main
import (
"context"
"encoding/json"
"flag"
"fmt"
@ -118,9 +119,11 @@ func main() {
"HTTPOneAddrs, TLSALPNOneAddrs or DNSAddrs\n")
}
go cmd.CatchSignals(nil, nil)
ctx, cancel := context.WithCancel(context.Background())
go cmd.CatchSignals(cancel)
err = s.Run(
ctx,
config.HTTPOneAddrs,
config.TLSALPNOneAddrs,
config.DNSAddrs,

View File

@ -2,6 +2,7 @@ package main
import (
"bytes"
"context"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
@ -15,14 +16,12 @@ import (
"net"
"net/http"
"os"
"os/signal"
"reflect"
"runtime"
"sort"
"strings"
"sync"
"sync/atomic"
"syscall"
"time"
"gopkg.in/go-jose/go-jose.v2"
@ -54,7 +53,7 @@ func (acct *account) update(finalizedOrders, certs []string) {
acct.certs = append(acct.certs, certs...)
}
type context struct {
type acmeCache struct {
// The current V2 account (may be nil for legacy load generation)
acct *account
// Pending orders waiting for authorization challenge validation
@ -70,12 +69,12 @@ type context struct {
ns *nonceSource
}
// signEmbeddedV2Request signs the provided request data using the context's
// signEmbeddedV2Request signs the provided request data using the acmeCache's
// account's private key. The provided URL is set as a protected header per ACME
// v2 JWS standards. The resulting JWS contains an **embedded** JWK - this makes
// this function primarily applicable to new account requests where no key ID is
// known.
func (c *context) signEmbeddedV2Request(data []byte, url string) (*jose.JSONWebSignature, error) {
func (c *acmeCache) signEmbeddedV2Request(data []byte, url string) (*jose.JSONWebSignature, error) {
// Create a signing key for the account's private key
signingKey := jose.SigningKey{
Key: c.acct.key,
@ -101,13 +100,13 @@ func (c *context) signEmbeddedV2Request(data []byte, url string) (*jose.JSONWebS
return signed, nil
}
// signKeyIDV2Request signs the provided request data using the context's
// signKeyIDV2Request signs the provided request data using the acmeCache's
// account's private key. The provided URL is set as a protected header per ACME
// v2 JWS standards. The resulting JWS contains a Key ID header that is
// populated using the context's account's ID. This is the default JWS signing
// populated using the acmeCache's account's ID. This is the default JWS signing
// style for ACME v2 requests and should be used everywhere but where the key ID
// is unknown (e.g. new-account requests where an account doesn't exist yet).
func (c *context) signKeyIDV2Request(data []byte, url string) (*jose.JSONWebSignature, error) {
func (c *acmeCache) signKeyIDV2Request(data []byte, url string) (*jose.JSONWebSignature, error) {
// Create a JWK with the account's private key and key ID
jwk := &jose.JSONWebKey{
Key: c.acct.key,
@ -168,7 +167,7 @@ type State struct {
realIP string
certKey *ecdsa.PrivateKey
operations []func(*State, *context) error
operations []func(*State, *acmeCache) error
rMu sync.RWMutex
@ -349,6 +348,7 @@ func New(
// Run runs the WFE load-generator
func (s *State) Run(
ctx context.Context,
httpOneAddrs []string,
tlsALPNOneAddrs []string,
dnsAddrs []string,
@ -387,8 +387,6 @@ func (s *State) Run(
// Run sending loop
stop := make(chan bool, 1)
sigs := make(chan os.Signal, 1)
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
fmt.Println("[+] Beginning execution plan")
i := int64(0)
go func() {
@ -429,8 +427,8 @@ func (s *State) Run(
select {
case <-time.After(p.Runtime):
fmt.Println("[+] Execution plan finished")
case sig := <-sigs:
fmt.Printf("[!] Execution plan interrupted: %s caught\n", sig.String())
case <-ctx.Done():
fmt.Println("[!] Execution plan cancelled")
}
stop <- true
fmt.Println("[+] Waiting for pending flows to finish before killing challenge server")
@ -586,19 +584,19 @@ func (s *State) addAccount(acct *account) {
func (s *State) sendCall() {
defer s.wg.Done()
ctx := &context{}
c := &acmeCache{}
for _, op := range s.operations {
err := op(s, ctx)
err := op(s, c)
if err != nil {
method := runtime.FuncForPC(reflect.ValueOf(op).Pointer()).Name()
fmt.Printf("[FAILED] %s: %s\n", method, err)
break
}
}
// If the context's V2 account isn't nil, update it based on the context's
// If the acmeCache's V2 account isn't nil, update it based on the cache's
// finalizedOrders and certs.
if ctx.acct != nil {
ctx.acct.update(ctx.finalizedOrders, ctx.certs)
if c.acct != nil {
c.acct.update(c.finalizedOrders, c.certs)
}
}

View File

@ -3,6 +3,7 @@ package main
import (
"bufio"
"bytes"
"context"
"crypto/tls"
"flag"
"fmt"
@ -186,11 +187,19 @@ scan:
}
}
func (srv *mailSrv) serveSMTP(l net.Listener) error {
func (srv *mailSrv) serveSMTP(ctx context.Context, l net.Listener) error {
for {
conn, err := l.Accept()
if err != nil {
return err
// If the accept call returned an error because the listener has been
// closed, then the context should have been canceled too. In that case,
// ignore the error.
select {
case <-ctx.Done():
return nil
default:
return err
}
}
go srv.handleConn(conn)
}
@ -233,10 +242,10 @@ func main() {
}
}()
go cmd.CatchSignals(nil, nil)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
err = srv.serveSMTP(l)
if err != nil {
log.Fatalln(err, "Failed to accept connection")
}
go cmd.FailOnError(srv.serveSMTP(ctx, l), "Failed to accept connection")
cmd.WaitForSignal()
}

View File

@ -1,12 +1,13 @@
package main
import (
"context"
"flag"
"fmt"
"io"
"log"
"net/http"
"sync"
"time"
"github.com/letsencrypt/boulder/cmd"
"github.com/letsencrypt/boulder/core"
@ -93,9 +94,23 @@ func main() {
http.HandleFunc("/clear", srv.handleClear)
http.HandleFunc("/query", srv.handleQuery)
// The gosec linter complains that timeouts cannot be set here. That's fine,
// because this is test-only code.
////nolint:gosec
go log.Fatal(http.ListenAndServe(*listenAddr, nil))
cmd.CatchSignals(nil, nil)
s := http.Server{
ReadTimeout: 30 * time.Second,
Addr: *listenAddr,
}
go func() {
err := s.ListenAndServe()
if err != nil && err != http.ErrServerClosed {
cmd.FailOnError(err, "Running TLS server")
}
}()
defer func() {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
s.Shutdown(ctx)
}()
cmd.WaitForSignal()
}