diff --git a/cmd/boulder-wfe2/main.go b/cmd/boulder-wfe2/main.go index 54393b3d0..942c9717d 100644 --- a/cmd/boulder-wfe2/main.go +++ b/cmd/boulder-wfe2/main.go @@ -535,9 +535,6 @@ func main() { } }() - // The gosec linter complains that ReadHeaderTimeout is not set. That's fine, - // because that field inherits its value from ReadTimeout. - ////nolint:gosec tlsSrv := http.Server{ ReadTimeout: 30 * time.Second, WriteTimeout: 120 * time.Second, @@ -555,20 +552,18 @@ func main() { }() } - done := make(chan bool) - go cmd.CatchSignals(logger, func() { + // When main is ready to exit (because it has received a shutdown signal), + // gracefully shutdown the servers. Calling these shutdown functions causes + // ListenAndServe() and ListenAndServeTLS() to immediately return, then waits + // for any lingering connection-handling goroutines to finish their work. + defer func() { ctx, cancel := context.WithTimeout(context.Background(), c.WFE.ShutdownStopTimeout.Duration) defer cancel() _ = srv.Shutdown(ctx) _ = tlsSrv.Shutdown(ctx) - done <- true - }) + }() - // https://godoc.org/net/http#Server.Shutdown: - // When Shutdown is called, Serve, ListenAndServe, and ListenAndServeTLS - // immediately return ErrServerClosed. Make sure the program doesn't exit and - // waits instead for Shutdown to return. - <-done + cmd.WaitForSignal() } func init() { diff --git a/cmd/crl-updater/main.go b/cmd/crl-updater/main.go index b1ad91de1..c83484309 100644 --- a/cmd/crl-updater/main.go +++ b/cmd/crl-updater/main.go @@ -2,6 +2,7 @@ package notmain import ( "context" + "errors" "flag" "os" "time" @@ -165,15 +166,17 @@ func main() { cmd.FailOnError(err, "Failed to create crl-updater") ctx, cancel := context.WithCancel(context.Background()) - go cmd.CatchSignals(logger, cancel) + go cmd.CatchSignals(cancel) if *runOnce { err = u.Tick(ctx, clk.Now()) - cmd.FailOnError(err, "") + if err != nil && !errors.Is(err, context.Canceled) { + cmd.FailOnError(err, "") + } } else { err = u.Run(ctx) - if err != nil { - logger.Err(err.Error()) + if err != nil && !errors.Is(err, context.Canceled) { + cmd.FailOnError(err, "") } } } diff --git a/cmd/expiration-mailer/main.go b/cmd/expiration-mailer/main.go index d81ebed67..09589b56d 100644 --- a/cmd/expiration-mailer/main.go +++ b/cmd/expiration-mailer/main.go @@ -803,6 +803,11 @@ func main() { defer logger.AuditPanic() logger.Info(cmd.VersionString()) + if *daemon && c.Mailer.Frequency.Duration == 0 { + fmt.Fprintln(os.Stderr, "mailer.frequency is not set in the JSON config") + os.Exit(1) + } + if *certLimit > 0 { c.Mailer.CertLimit = *certLimit } @@ -913,31 +918,26 @@ func main() { } ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - go cmd.CatchSignals(logger, func() { - cancel() - select {} // wait for the `findExpiringCertificates` calls below to exit - }) + go cmd.CatchSignals(cancel) if *daemon { - if c.Mailer.Frequency.Duration == 0 { - fmt.Fprintln(os.Stderr, "mailer.Frequency is not set in the JSON config") - os.Exit(1) - } t := time.NewTicker(c.Mailer.Frequency.Duration) for { select { case <-t.C: err = m.findExpiringCertificates(ctx) - cmd.FailOnError(err, "expiration-mailer has failed") + if err != nil && !errors.Is(err, context.Canceled) { + cmd.FailOnError(err, "expiration-mailer has failed") + } case <-ctx.Done(): - os.Exit(0) + break } } } else { err = m.findExpiringCertificates(ctx) - cmd.FailOnError(err, "expiration-mailer has failed") + if err != nil && !errors.Is(err, context.Canceled) { + cmd.FailOnError(err, "expiration-mailer has failed") + } } } diff --git a/cmd/log-validator/main.go b/cmd/log-validator/main.go index 66d58ade3..ab7c456b8 100644 --- a/cmd/log-validator/main.go +++ b/cmd/log-validator/main.go @@ -199,7 +199,7 @@ func main() { tailers = append(tailers, t) } - cmd.CatchSignals(logger, func() { + defer func() { for _, t := range tailers { // The tail module seems to have a race condition that will generate // errors like this on shutdown: @@ -211,7 +211,9 @@ func main() { _ = t.Stop() t.Cleanup() } - }) + }() + + cmd.WaitForSignal() } func init() { diff --git a/cmd/ocsp-responder/main.go b/cmd/ocsp-responder/main.go index ff250d06e..8227f2606 100644 --- a/cmd/ocsp-responder/main.go +++ b/cmd/ocsp-responder/main.go @@ -227,25 +227,24 @@ as generated by Boulder's ceremony command. Handler: m, } - done := make(chan bool) - go cmd.CatchSignals(logger, func() { - ctx, cancel := context.WithTimeout(context.Background(), - c.OCSPResponder.ShutdownStopTimeout.Duration) - defer cancel() - _ = srv.Shutdown(ctx) - done <- true - }) - err = srv.ListenAndServe() if err != nil && err != http.ErrServerClosed { cmd.FailOnError(err, "Running HTTP server") } - // https://godoc.org/net/http#Server.Shutdown: - // When Shutdown is called, Serve, ListenAndServe, and ListenAndServeTLS - // immediately return ErrServerClosed. Make sure the program doesn't exit and - // waits instead for Shutdown to return. - <-done + // When main is ready to exit (because it has received a shutdown signal), + // gracefully shutdown the servers. Calling these shutdown functions causes + // ListenAndServe() to immediately return, cleaning up the server goroutines + // as well, then waits for any lingering connection-handing goroutines to + // finish and clean themselves up. + defer func() { + ctx, cancel := context.WithTimeout(context.Background(), + c.OCSPResponder.ShutdownStopTimeout.Duration) + defer cancel() + srv.Shutdown(ctx) + }() + + cmd.WaitForSignal() } // ocspMux partially implements the interface defined for http.ServeMux but doesn't implement diff --git a/cmd/shell.go b/cmd/shell.go index db0f9cbc3..e198d11a9 100644 --- a/cmd/shell.go +++ b/cmd/shell.go @@ -418,18 +418,26 @@ func VersionString() string { return fmt.Sprintf("Versions: %s=(%s %s) Golang=(%s) BuildHost=(%s)", command(), core.GetBuildID(), core.GetBuildTime(), runtime.Version(), core.GetBuildHost()) } -// CatchSignals catches SIGTERM, SIGINT, SIGHUP and executes a callback -// method before exiting -func CatchSignals(logger blog.Logger, callback func()) { +// CatchSignals blocks until a SIGTERM, SIGINT, or SIGHUP is received, then +// executes the given callback. The callback should not block, it should simply +// signal other goroutines (particularly the main goroutine) to clean themselves +// up and exit. This function is intended to be called in its own goroutine, +// while the main goroutine waits for an indication that the other goroutines +// have exited cleanly. +func CatchSignals(callback func()) { + WaitForSignal() + callback() +} + +// WaitForSignal blocks until a SIGTERM, SIGINT, or SIGHUP is received. It then +// returns, allowing execution to resume, generally allowing a main() function +// to return and trigger and deferred cleanup functions. This function is +// intended to be called directly from the main goroutine, while a gRPC or HTTP +// server runs in a background goroutine. +func WaitForSignal() { sigChan := make(chan os.Signal, 1) signal.Notify(sigChan, syscall.SIGTERM) signal.Notify(sigChan, syscall.SIGINT) signal.Notify(sigChan, syscall.SIGHUP) - <-sigChan - if callback != nil { - callback() - } - - os.Exit(0) } diff --git a/grpc/server.go b/grpc/server.go index bf8825c49..fb176c2a2 100644 --- a/grpc/server.go +++ b/grpc/server.go @@ -5,10 +5,7 @@ import ( "errors" "fmt" "net" - "os" - "os/signal" "strings" - "syscall" grpc_prometheus "github.com/grpc-ecosystem/go-grpc-prometheus" "github.com/jmhodges/clock" @@ -175,15 +172,10 @@ func (sb *serverBuilder) Build(tlsConfig *tls.Config, statsRegistry prometheus.R // Start a goroutine which listens for a termination signal, and then // gracefully stops the gRPC server. This in turn causes the start() function // to exit, allowing its caller (generally a main() function) to exit. - go func() { - sigChan := make(chan os.Signal, 1) - signal.Notify(sigChan, syscall.SIGTERM) - signal.Notify(sigChan, syscall.SIGINT) - signal.Notify(sigChan, syscall.SIGHUP) - <-sigChan + go cmd.CatchSignals(func() { healthSrv.Shutdown() server.GracefulStop() - }() + }) return start, nil } diff --git a/test/akamai-test-srv/main.go b/test/akamai-test-srv/main.go index b9fd85c0f..e6af7e566 100644 --- a/test/akamai-test-srv/main.go +++ b/test/akamai-test-srv/main.go @@ -1,13 +1,14 @@ package main import ( + "context" "encoding/json" "flag" "fmt" "io" - "log" "net/http" "sync" + "time" "github.com/letsencrypt/boulder/akamai" "github.com/letsencrypt/boulder/cmd" @@ -92,9 +93,23 @@ func main() { w.Write(resp) }) - // The gosec linter complains that timeouts cannot be set here. That's fine, - // because this is test-only code. - ////nolint:gosec - go log.Fatal(http.ListenAndServe(*listenAddr, nil)) - cmd.CatchSignals(nil, nil) + s := http.Server{ + ReadTimeout: 30 * time.Second, + Addr: *listenAddr, + } + + go func() { + err := s.ListenAndServe() + if err != nil && err != http.ErrServerClosed { + cmd.FailOnError(err, "Running TLS server") + } + }() + + defer func() { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + s.Shutdown(ctx) + }() + + cmd.WaitForSignal() } diff --git a/test/ct-test-srv/main.go b/test/ct-test-srv/main.go index 0bb728d9e..564ad85f7 100644 --- a/test/ct-test-srv/main.go +++ b/test/ct-test-srv/main.go @@ -257,5 +257,5 @@ func main() { for _, p := range c.Personalities { go runPersonality(p) } - cmd.CatchSignals(nil, nil) + cmd.WaitForSignal() } diff --git a/test/load-generator/boulder-calls.go b/test/load-generator/boulder-calls.go index a98f02eaf..52690dd51 100644 --- a/test/load-generator/boulder-calls.go +++ b/test/load-generator/boulder-calls.go @@ -31,7 +31,7 @@ import ( var ( // stringToOperation maps a configured plan action to a function that can // operate on a state/context. - stringToOperation = map[string]func(*State, *context) error{ + stringToOperation = map[string]func(*State, *acmeCache) error{ "newAccount": newAccount, "getAccount": getAccount, "newOrder": newOrder, @@ -60,8 +60,8 @@ type OrderJSON struct { } // getAccount takes a randomly selected v2 account from `state.accts` and puts it -// into `ctx.acct`. The context `nonceSource` is also populated as convenience. -func getAccount(s *State, ctx *context) error { +// into `c.acct`. The context `nonceSource` is also populated as convenience. +func getAccount(s *State, c *acmeCache) error { s.rMu.RLock() defer s.rMu.RUnlock() @@ -71,8 +71,8 @@ func getAccount(s *State, ctx *context) error { } // Select a random account from the state and put it into the context - ctx.acct = s.accts[mrand.Intn(len(s.accts))] - ctx.ns = &nonceSource{s: s} + c.acct = s.accts[mrand.Intn(len(s.accts))] + c.ns = &nonceSource{s: s} return nil } @@ -81,11 +81,11 @@ func getAccount(s *State, ctx *context) error { // then `newAccount` puts an existing account from the state into the context, // otherwise it creates a new account and puts it into both the state and the // context. -func newAccount(s *State, ctx *context) error { +func newAccount(s *State, c *acmeCache) error { // Check the max regs and if exceeded, just return an existing account instead // of creating a new one. if s.maxRegs != 0 && s.numAccts() >= s.maxRegs { - return getAccount(s, ctx) + return getAccount(s, c) } // Create a random signing key @@ -93,10 +93,10 @@ func newAccount(s *State, ctx *context) error { if err != nil { return err } - ctx.acct = &account{ + c.acct = &account{ key: signKey, } - ctx.ns = &nonceSource{s: s} + c.ns = &nonceSource{s: s} // Prepare an account registration message body reqBody := struct { @@ -117,7 +117,7 @@ func newAccount(s *State, ctx *context) error { // Sign the new account registration body using a JWS with an embedded JWK // because we do not have a key ID from the server yet. newAccountURL := s.directory.EndpointURL(acme.NewAccountEndpoint) - jws, err := ctx.signEmbeddedV2Request(reqBodyStr, newAccountURL) + jws, err := c.signEmbeddedV2Request(reqBodyStr, newAccountURL) if err != nil { return err } @@ -126,7 +126,7 @@ func newAccount(s *State, ctx *context) error { resp, err := s.post( newAccountURL, bodyBuf, - ctx.ns, + c.ns, string(acme.NewAccountEndpoint), http.StatusCreated) if err != nil { @@ -140,10 +140,10 @@ func newAccount(s *State, ctx *context) error { if locHeader == "" { return fmt.Errorf("%s, bad response - no Location header with account ID", newAccountURL) } - ctx.acct.id = locHeader + c.acct.id = locHeader // Add the account to the state - s.addAccount(ctx.acct) + s.addAccount(c.acct) return nil } @@ -160,7 +160,7 @@ func randDomain(base string) string { // newOrder creates a new pending order object for a random set of domains using // the context's account. -func newOrder(s *State, ctx *context) error { +func newOrder(s *State, c *acmeCache) error { // Pick a random number of names within the constraints of the maxNamesPerCert // parameter orderSize := 1 + mrand.Intn(s.maxNamesPerCert-1) @@ -187,7 +187,7 @@ func newOrder(s *State, ctx *context) error { // Sign the new order request with the context account's key/key ID newOrderURL := s.directory.EndpointURL(acme.NewOrderEndpoint) - jws, err := ctx.signKeyIDV2Request(initOrderStr, newOrderURL) + jws, err := c.signKeyIDV2Request(initOrderStr, newOrderURL) if err != nil { return err } @@ -196,7 +196,7 @@ func newOrder(s *State, ctx *context) error { resp, err := s.post( newOrderURL, bodyBuf, - ctx.ns, + c.ns, string(acme.NewOrderEndpoint), http.StatusCreated) if err != nil { @@ -223,24 +223,24 @@ func newOrder(s *State, ctx *context) error { orderJSON.URL = orderURL // Store the pending order in the context - ctx.pendingOrders = append(ctx.pendingOrders, &orderJSON) + c.pendingOrders = append(c.pendingOrders, &orderJSON) return nil } // popPendingOrder *removes* a random pendingOrder from the context, returning // it. -func popPendingOrder(ctx *context) *OrderJSON { - orderIndex := mrand.Intn(len(ctx.pendingOrders)) - order := ctx.pendingOrders[orderIndex] - ctx.pendingOrders = append(ctx.pendingOrders[:orderIndex], ctx.pendingOrders[orderIndex+1:]...) +func popPendingOrder(c *acmeCache) *OrderJSON { + orderIndex := mrand.Intn(len(c.pendingOrders)) + order := c.pendingOrders[orderIndex] + c.pendingOrders = append(c.pendingOrders[:orderIndex], c.pendingOrders[orderIndex+1:]...) return order } // getAuthorization fetches an authorization by GET-ing the provided URL. It // records the latency and result of the GET operation in the state. -func getAuthorization(s *State, ctx *context, url string) (*core.Authorization, error) { +func getAuthorization(s *State, c *acmeCache, url string) (*core.Authorization, error) { latencyTag := "/acme/authz/{ID}" - resp, err := postAsGet(s, ctx, url, latencyTag) + resp, err := postAsGet(s, c, url, latencyTag) // If there was an error, note the state and return if err != nil { return nil, fmt.Errorf("%s bad response: %s", url, err) @@ -269,7 +269,7 @@ func getAuthorization(s *State, ctx *context, url string) (*core.Authorization, // HTTP-01 challenge using the context's account and the state's challenge // server. Aftering POSTing the authorization's HTTP-01 challenge the // authorization will be polled waiting for a state change. -func completeAuthorization(authz *core.Authorization, s *State, ctx *context) error { +func completeAuthorization(authz *core.Authorization, s *State, c *acmeCache) error { // Skip if the authz isn't pending if authz.Status != core.StatusPending { return nil @@ -283,7 +283,7 @@ func completeAuthorization(authz *core.Authorization, s *State, ctx *context) er } // Compute the key authorization from the context account's key - jwk := &jose.JSONWebKey{Key: &ctx.acct.key.PublicKey} + jwk := &jose.JSONWebKey{Key: &c.acct.key.PublicKey} thumbprint, err := jwk.Thumbprint(crypto.SHA256) if err != nil { return err @@ -311,7 +311,7 @@ func completeAuthorization(authz *core.Authorization, s *State, ctx *context) er } // Prepare the Challenge POST body - jws, err := ctx.signKeyIDV2Request([]byte(`{}`), chalToSolve.URL) + jws, err := c.signKeyIDV2Request([]byte(`{}`), chalToSolve.URL) if err != nil { return err } @@ -320,7 +320,7 @@ func completeAuthorization(authz *core.Authorization, s *State, ctx *context) er resp, err := s.post( chalToSolve.URL, requestPayload, - ctx.ns, + c.ns, "/acme/challenge/{ID}", // We want all challenge POST latencies to be grouped http.StatusOK, ) @@ -337,7 +337,7 @@ func completeAuthorization(authz *core.Authorization, s *State, ctx *context) er // Poll the authorization waiting for the challenge response to be recorded in // a change of state. The polling may sleep and retry a few times if required - err = pollAuthorization(authz, s, ctx) + err = pollAuthorization(authz, s, c) if err != nil { return err } @@ -351,11 +351,11 @@ func completeAuthorization(authz *core.Authorization, s *State, ctx *context) er // be valid. If the status is invalid, or if three GETs do not produce the // correct authorization state an error is returned. If no error is returned // then the authorization is valid and ready. -func pollAuthorization(authz *core.Authorization, s *State, ctx *context) error { +func pollAuthorization(authz *core.Authorization, s *State, c *acmeCache) error { authzURL := authz.ID for i := 0; i < 3; i++ { // Fetch the authz by its URL - authz, err := getAuthorization(s, ctx, authzURL) + authz, err := getAuthorization(s, c, authzURL) if err != nil { return nil } @@ -377,25 +377,25 @@ func pollAuthorization(authz *core.Authorization, s *State, ctx *context) error // authorization's HTTP-01 challenge using the context's account, and finally // placing the now-ready-to-be-finalized order into the context's list of // fulfilled orders. -func fulfillOrder(s *State, ctx *context) error { +func fulfillOrder(s *State, c *acmeCache) error { // There must be at least one pending order in the context to fulfill - if len(ctx.pendingOrders) == 0 { + if len(c.pendingOrders) == 0 { return errors.New("no pending orders to fulfill") } // Get an order to fulfill from the context - order := popPendingOrder(ctx) + order := popPendingOrder(c) // Each of its authorizations need to be processed for _, url := range order.Authorizations { // Fetch the authz by its URL - authz, err := getAuthorization(s, ctx, url) + authz, err := getAuthorization(s, c, url) if err != nil { return err } // Complete the authorization by solving a challenge - err = completeAuthorization(authz, s, ctx) + err = completeAuthorization(authz, s, c) if err != nil { return err } @@ -403,16 +403,16 @@ func fulfillOrder(s *State, ctx *context) error { // Once all of the authorizations have been fulfilled the order is fulfilled // and ready for future finalization. - ctx.fulfilledOrders = append(ctx.fulfilledOrders, order.URL) + c.fulfilledOrders = append(c.fulfilledOrders, order.URL) return nil } // getOrder GETs an order by URL, returning an OrderJSON object. It tracks the // latency of the GET operation in the provided state. -func getOrder(s *State, ctx *context, url string) (*OrderJSON, error) { +func getOrder(s *State, c *acmeCache, url string) (*OrderJSON, error) { latencyTag := "/acme/order/{ID}" // POST-as-GET the order URL - resp, err := postAsGet(s, ctx, url, latencyTag) + resp, err := postAsGet(s, c, url, latencyTag) // If there was an error, track that result if err != nil { return nil, fmt.Errorf("%s bad response: %s", url, err) @@ -440,10 +440,10 @@ func getOrder(s *State, ctx *context, url string) (*OrderJSON, error) { // valid such that a certificate URL for the order is known. Three attempts are // made to check the order status, sleeping 3s between each. If these attempts // expire without the status becoming valid an error is returned. -func pollOrderForCert(order *OrderJSON, s *State, ctx *context) (*OrderJSON, error) { +func pollOrderForCert(order *OrderJSON, s *State, c *acmeCache) (*OrderJSON, error) { for i := 0; i < 3; i++ { // Fetch the order by its URL - order, err := getOrder(s, ctx, order.URL) + order, err := getOrder(s, c, order.URL) if err != nil { return nil, err } @@ -463,10 +463,10 @@ func pollOrderForCert(order *OrderJSON, s *State, ctx *context) (*OrderJSON, err // popFulfilledOrder **removes** a fulfilled order from the context, returning // it. Fulfilled orders have all of their authorizations satisfied. -func popFulfilledOrder(ctx *context) string { - orderIndex := mrand.Intn(len(ctx.fulfilledOrders)) - order := ctx.fulfilledOrders[orderIndex] - ctx.fulfilledOrders = append(ctx.fulfilledOrders[:orderIndex], ctx.fulfilledOrders[orderIndex+1:]...) +func popFulfilledOrder(c *acmeCache) string { + orderIndex := mrand.Intn(len(c.fulfilledOrders)) + order := c.fulfilledOrders[orderIndex] + c.fulfilledOrders = append(c.fulfilledOrders[:orderIndex], c.fulfilledOrders[orderIndex+1:]...) return order } @@ -475,15 +475,15 @@ func popFulfilledOrder(ctx *context) string { // `certKey`. The order is then polled for the status to change to valid so that // the certificate URL can be added to the context. The context's `certs` list // is updated with the URL for the order's certificate. -func finalizeOrder(s *State, ctx *context) error { +func finalizeOrder(s *State, c *acmeCache) error { // There must be at least one fulfilled order in the context - if len(ctx.fulfilledOrders) < 1 { + if len(c.fulfilledOrders) < 1 { return errors.New("No fulfilled orders in the context ready to be finalized") } // Pop a fulfilled order to process, and then GET its contents - orderID := popFulfilledOrder(ctx) - order, err := getOrder(s, ctx, orderID) + orderID := popFulfilledOrder(c) + order, err := getOrder(s, c, orderID) if err != nil { return err } @@ -519,7 +519,7 @@ func finalizeOrder(s *State, ctx *context) error { ) // Sign the request body with the context's account key/keyID - jws, err := ctx.signKeyIDV2Request([]byte(request), finalizeURL) + jws, err := c.signKeyIDV2Request([]byte(request), finalizeURL) if err != nil { return err } @@ -528,7 +528,7 @@ func finalizeOrder(s *State, ctx *context) error { resp, err := s.post( finalizeURL, requestPayload, - ctx.ns, + c.ns, "/acme/order/finalize", // We want all order finalizations to be grouped. http.StatusOK, ) @@ -544,7 +544,7 @@ func finalizeOrder(s *State, ctx *context) error { } // Poll the order waiting for the certificate to be ready - completedOrder, err := pollOrderForCert(order, s, ctx) + completedOrder, err := pollOrderForCert(order, s, c) if err != nil { return err } @@ -556,8 +556,8 @@ func finalizeOrder(s *State, ctx *context) error { } // Append the certificate URL into the context's list of certificates - ctx.certs = append(ctx.certs, certURL) - ctx.finalizedOrders = append(ctx.finalizedOrders, order.URL) + c.certs = append(c.certs, certURL) + c.finalizedOrders = append(c.finalizedOrders, order.URL) return nil } @@ -567,27 +567,27 @@ func finalizeOrder(s *State, ctx *context) error { // responsible for closing the HTTP response body. // // See RFC 8555 Section 6.3 for more information on POST-as-GET requests. -func postAsGet(s *State, ctx *context, url string, latencyTag string) (*http.Response, error) { +func postAsGet(s *State, c *acmeCache, url string, latencyTag string) (*http.Response, error) { // Create the POST-as-GET request JWS - jws, err := ctx.signKeyIDV2Request([]byte(""), url) + jws, err := c.signKeyIDV2Request([]byte(""), url) if err != nil { return nil, err } requestPayload := []byte(jws.FullSerialize()) - return s.post(url, requestPayload, ctx.ns, latencyTag, http.StatusOK) + return s.post(url, requestPayload, c.ns, latencyTag, http.StatusOK) } -func popCertificate(ctx *context) string { - certIndex := mrand.Intn(len(ctx.certs)) - certURL := ctx.certs[certIndex] - ctx.certs = append(ctx.certs[:certIndex], ctx.certs[certIndex+1:]...) +func popCertificate(c *acmeCache) string { + certIndex := mrand.Intn(len(c.certs)) + certURL := c.certs[certIndex] + c.certs = append(c.certs[:certIndex], c.certs[certIndex+1:]...) return certURL } -func getCert(s *State, ctx *context, url string) ([]byte, error) { +func getCert(s *State, c *acmeCache, url string) ([]byte, error) { latencyTag := "/acme/cert/{serial}" - resp, err := postAsGet(s, ctx, url, latencyTag) + resp, err := postAsGet(s, c, url, latencyTag) if err != nil { return nil, fmt.Errorf("%s bad response: %s", url, err) } @@ -599,8 +599,8 @@ func getCert(s *State, ctx *context, url string) ([]byte, error) { // and sends a revocation request for the certificate to the ACME server. // The revocation request is signed with the account key rather than the certificate // key. -func revokeCertificate(s *State, ctx *context) error { - if len(ctx.certs) < 1 { +func revokeCertificate(s *State, c *acmeCache) error { + if len(c.certs) < 1 { return errors.New("No certificates in the context that can be revoked") } @@ -608,8 +608,8 @@ func revokeCertificate(s *State, ctx *context) error { return nil } - certURL := popCertificate(ctx) - certPEM, err := getCert(s, ctx, certURL) + certURL := popCertificate(c) + certPEM, err := getCert(s, c, certURL) if err != nil { return err } @@ -630,7 +630,7 @@ func revokeCertificate(s *State, ctx *context) error { revokeURL := s.directory.EndpointURL(acme.RevokeCertEndpoint) // TODO(roland): randomly use the certificate key to sign the request instead of // the account key - jws, err := ctx.signKeyIDV2Request(revokeJSON, revokeURL) + jws, err := c.signKeyIDV2Request(revokeJSON, revokeURL) if err != nil { return err } @@ -639,7 +639,7 @@ func revokeCertificate(s *State, ctx *context) error { resp, err := s.post( revokeURL, requestPayload, - ctx.ns, + c.ns, "/acme/revoke-cert", http.StatusOK, ) diff --git a/test/load-generator/main.go b/test/load-generator/main.go index a5586a064..1baed0673 100644 --- a/test/load-generator/main.go +++ b/test/load-generator/main.go @@ -1,6 +1,7 @@ package main import ( + "context" "encoding/json" "flag" "fmt" @@ -118,9 +119,11 @@ func main() { "HTTPOneAddrs, TLSALPNOneAddrs or DNSAddrs\n") } - go cmd.CatchSignals(nil, nil) + ctx, cancel := context.WithCancel(context.Background()) + go cmd.CatchSignals(cancel) err = s.Run( + ctx, config.HTTPOneAddrs, config.TLSALPNOneAddrs, config.DNSAddrs, diff --git a/test/load-generator/state.go b/test/load-generator/state.go index cf04e1495..bbb4e866f 100644 --- a/test/load-generator/state.go +++ b/test/load-generator/state.go @@ -2,6 +2,7 @@ package main import ( "bytes" + "context" "crypto/ecdsa" "crypto/elliptic" "crypto/rand" @@ -15,14 +16,12 @@ import ( "net" "net/http" "os" - "os/signal" "reflect" "runtime" "sort" "strings" "sync" "sync/atomic" - "syscall" "time" "gopkg.in/go-jose/go-jose.v2" @@ -54,7 +53,7 @@ func (acct *account) update(finalizedOrders, certs []string) { acct.certs = append(acct.certs, certs...) } -type context struct { +type acmeCache struct { // The current V2 account (may be nil for legacy load generation) acct *account // Pending orders waiting for authorization challenge validation @@ -70,12 +69,12 @@ type context struct { ns *nonceSource } -// signEmbeddedV2Request signs the provided request data using the context's +// signEmbeddedV2Request signs the provided request data using the acmeCache's // account's private key. The provided URL is set as a protected header per ACME // v2 JWS standards. The resulting JWS contains an **embedded** JWK - this makes // this function primarily applicable to new account requests where no key ID is // known. -func (c *context) signEmbeddedV2Request(data []byte, url string) (*jose.JSONWebSignature, error) { +func (c *acmeCache) signEmbeddedV2Request(data []byte, url string) (*jose.JSONWebSignature, error) { // Create a signing key for the account's private key signingKey := jose.SigningKey{ Key: c.acct.key, @@ -101,13 +100,13 @@ func (c *context) signEmbeddedV2Request(data []byte, url string) (*jose.JSONWebS return signed, nil } -// signKeyIDV2Request signs the provided request data using the context's +// signKeyIDV2Request signs the provided request data using the acmeCache's // account's private key. The provided URL is set as a protected header per ACME // v2 JWS standards. The resulting JWS contains a Key ID header that is -// populated using the context's account's ID. This is the default JWS signing +// populated using the acmeCache's account's ID. This is the default JWS signing // style for ACME v2 requests and should be used everywhere but where the key ID // is unknown (e.g. new-account requests where an account doesn't exist yet). -func (c *context) signKeyIDV2Request(data []byte, url string) (*jose.JSONWebSignature, error) { +func (c *acmeCache) signKeyIDV2Request(data []byte, url string) (*jose.JSONWebSignature, error) { // Create a JWK with the account's private key and key ID jwk := &jose.JSONWebKey{ Key: c.acct.key, @@ -168,7 +167,7 @@ type State struct { realIP string certKey *ecdsa.PrivateKey - operations []func(*State, *context) error + operations []func(*State, *acmeCache) error rMu sync.RWMutex @@ -349,6 +348,7 @@ func New( // Run runs the WFE load-generator func (s *State) Run( + ctx context.Context, httpOneAddrs []string, tlsALPNOneAddrs []string, dnsAddrs []string, @@ -387,8 +387,6 @@ func (s *State) Run( // Run sending loop stop := make(chan bool, 1) - sigs := make(chan os.Signal, 1) - signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM) fmt.Println("[+] Beginning execution plan") i := int64(0) go func() { @@ -429,8 +427,8 @@ func (s *State) Run( select { case <-time.After(p.Runtime): fmt.Println("[+] Execution plan finished") - case sig := <-sigs: - fmt.Printf("[!] Execution plan interrupted: %s caught\n", sig.String()) + case <-ctx.Done(): + fmt.Println("[!] Execution plan cancelled") } stop <- true fmt.Println("[+] Waiting for pending flows to finish before killing challenge server") @@ -586,19 +584,19 @@ func (s *State) addAccount(acct *account) { func (s *State) sendCall() { defer s.wg.Done() - ctx := &context{} + c := &acmeCache{} for _, op := range s.operations { - err := op(s, ctx) + err := op(s, c) if err != nil { method := runtime.FuncForPC(reflect.ValueOf(op).Pointer()).Name() fmt.Printf("[FAILED] %s: %s\n", method, err) break } } - // If the context's V2 account isn't nil, update it based on the context's + // If the acmeCache's V2 account isn't nil, update it based on the cache's // finalizedOrders and certs. - if ctx.acct != nil { - ctx.acct.update(ctx.finalizedOrders, ctx.certs) + if c.acct != nil { + c.acct.update(c.finalizedOrders, c.certs) } } diff --git a/test/mail-test-srv/main.go b/test/mail-test-srv/main.go index d3ee84e42..3d13532a5 100644 --- a/test/mail-test-srv/main.go +++ b/test/mail-test-srv/main.go @@ -3,6 +3,7 @@ package main import ( "bufio" "bytes" + "context" "crypto/tls" "flag" "fmt" @@ -186,11 +187,19 @@ scan: } } -func (srv *mailSrv) serveSMTP(l net.Listener) error { +func (srv *mailSrv) serveSMTP(ctx context.Context, l net.Listener) error { for { conn, err := l.Accept() if err != nil { - return err + // If the accept call returned an error because the listener has been + // closed, then the context should have been canceled too. In that case, + // ignore the error. + select { + case <-ctx.Done(): + return nil + default: + return err + } } go srv.handleConn(conn) } @@ -233,10 +242,10 @@ func main() { } }() - go cmd.CatchSignals(nil, nil) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() - err = srv.serveSMTP(l) - if err != nil { - log.Fatalln(err, "Failed to accept connection") - } + go cmd.FailOnError(srv.serveSMTP(ctx, l), "Failed to accept connection") + + cmd.WaitForSignal() } diff --git a/test/s3-test-srv/main.go b/test/s3-test-srv/main.go index 70408417f..ae998bb23 100644 --- a/test/s3-test-srv/main.go +++ b/test/s3-test-srv/main.go @@ -1,12 +1,13 @@ package main import ( + "context" "flag" "fmt" "io" - "log" "net/http" "sync" + "time" "github.com/letsencrypt/boulder/cmd" "github.com/letsencrypt/boulder/core" @@ -93,9 +94,23 @@ func main() { http.HandleFunc("/clear", srv.handleClear) http.HandleFunc("/query", srv.handleQuery) - // The gosec linter complains that timeouts cannot be set here. That's fine, - // because this is test-only code. - ////nolint:gosec - go log.Fatal(http.ListenAndServe(*listenAddr, nil)) - cmd.CatchSignals(nil, nil) + s := http.Server{ + ReadTimeout: 30 * time.Second, + Addr: *listenAddr, + } + + go func() { + err := s.ListenAndServe() + if err != nil && err != http.ErrServerClosed { + cmd.FailOnError(err, "Running TLS server") + } + }() + + defer func() { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + s.Shutdown(ctx) + }() + + cmd.WaitForSignal() }