Fix non-gRPC process cleanup and exit (#6808)
Although #6771 significantly cleaned up how gRPC services stop and clean up, it didn't make any changes to our HTTP servers or our non-server (e.g. crl-updater, log-validator) processes. This change finishes the work. Add a new helper method cmd.WaitForSignal, which simply blocks until one of the three signals we care about is received. This easily replaces all calls to cmd.CatchSignals which passed `nil` as the callback argument, with the added advantage that it doesn't call os.Exit() and therefore allows deferred cleanup functions to execute. This new function is intended to be the last line of main(), allowing the whole process to exit once it returns. Reimplement cmd.CatchSignals as a thin wrapper around cmd.WaitForSignal, but with the added callback functionality. Also remove the os.Exit() call from CatchSignals, so that the main goroutine is allowed to finish whatever it's doing, call deferred functions, and exit naturally. Update all of our non-gRPC binaries to use one of these two functions. The vast majority use WaitForSignal, as they run their main processing loop in a background goroutine. A few (particularly those that can run either in run-once or in daemonized mode) still use CatchSignals, since their primary processing happens directly on the main goroutine. The changes to //test/load-generator are the most invasive, simply because that binary needed to have a context plumbed into it for proper cancellation, but it already had a custom struct type named "context" which needed to be renamed to avoid shadowing. Fixes https://github.com/letsencrypt/boulder/issues/6794
This commit is contained in:
parent
98fa0f07b4
commit
bd1d27b8e8
|
@ -535,9 +535,6 @@ func main() {
|
|||
}
|
||||
}()
|
||||
|
||||
// The gosec linter complains that ReadHeaderTimeout is not set. That's fine,
|
||||
// because that field inherits its value from ReadTimeout.
|
||||
////nolint:gosec
|
||||
tlsSrv := http.Server{
|
||||
ReadTimeout: 30 * time.Second,
|
||||
WriteTimeout: 120 * time.Second,
|
||||
|
@ -555,20 +552,18 @@ func main() {
|
|||
}()
|
||||
}
|
||||
|
||||
done := make(chan bool)
|
||||
go cmd.CatchSignals(logger, func() {
|
||||
// When main is ready to exit (because it has received a shutdown signal),
|
||||
// gracefully shutdown the servers. Calling these shutdown functions causes
|
||||
// ListenAndServe() and ListenAndServeTLS() to immediately return, then waits
|
||||
// for any lingering connection-handling goroutines to finish their work.
|
||||
defer func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), c.WFE.ShutdownStopTimeout.Duration)
|
||||
defer cancel()
|
||||
_ = srv.Shutdown(ctx)
|
||||
_ = tlsSrv.Shutdown(ctx)
|
||||
done <- true
|
||||
})
|
||||
}()
|
||||
|
||||
// https://godoc.org/net/http#Server.Shutdown:
|
||||
// When Shutdown is called, Serve, ListenAndServe, and ListenAndServeTLS
|
||||
// immediately return ErrServerClosed. Make sure the program doesn't exit and
|
||||
// waits instead for Shutdown to return.
|
||||
<-done
|
||||
cmd.WaitForSignal()
|
||||
}
|
||||
|
||||
func init() {
|
||||
|
|
|
@ -2,6 +2,7 @@ package notmain
|
|||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"flag"
|
||||
"os"
|
||||
"time"
|
||||
|
@ -165,15 +166,17 @@ func main() {
|
|||
cmd.FailOnError(err, "Failed to create crl-updater")
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
go cmd.CatchSignals(logger, cancel)
|
||||
go cmd.CatchSignals(cancel)
|
||||
|
||||
if *runOnce {
|
||||
err = u.Tick(ctx, clk.Now())
|
||||
cmd.FailOnError(err, "")
|
||||
if err != nil && !errors.Is(err, context.Canceled) {
|
||||
cmd.FailOnError(err, "")
|
||||
}
|
||||
} else {
|
||||
err = u.Run(ctx)
|
||||
if err != nil {
|
||||
logger.Err(err.Error())
|
||||
if err != nil && !errors.Is(err, context.Canceled) {
|
||||
cmd.FailOnError(err, "")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -803,6 +803,11 @@ func main() {
|
|||
defer logger.AuditPanic()
|
||||
logger.Info(cmd.VersionString())
|
||||
|
||||
if *daemon && c.Mailer.Frequency.Duration == 0 {
|
||||
fmt.Fprintln(os.Stderr, "mailer.frequency is not set in the JSON config")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if *certLimit > 0 {
|
||||
c.Mailer.CertLimit = *certLimit
|
||||
}
|
||||
|
@ -913,31 +918,26 @@ func main() {
|
|||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
go cmd.CatchSignals(logger, func() {
|
||||
cancel()
|
||||
select {} // wait for the `findExpiringCertificates` calls below to exit
|
||||
})
|
||||
go cmd.CatchSignals(cancel)
|
||||
|
||||
if *daemon {
|
||||
if c.Mailer.Frequency.Duration == 0 {
|
||||
fmt.Fprintln(os.Stderr, "mailer.Frequency is not set in the JSON config")
|
||||
os.Exit(1)
|
||||
}
|
||||
t := time.NewTicker(c.Mailer.Frequency.Duration)
|
||||
for {
|
||||
select {
|
||||
case <-t.C:
|
||||
err = m.findExpiringCertificates(ctx)
|
||||
cmd.FailOnError(err, "expiration-mailer has failed")
|
||||
if err != nil && !errors.Is(err, context.Canceled) {
|
||||
cmd.FailOnError(err, "expiration-mailer has failed")
|
||||
}
|
||||
case <-ctx.Done():
|
||||
os.Exit(0)
|
||||
break
|
||||
}
|
||||
}
|
||||
} else {
|
||||
err = m.findExpiringCertificates(ctx)
|
||||
cmd.FailOnError(err, "expiration-mailer has failed")
|
||||
if err != nil && !errors.Is(err, context.Canceled) {
|
||||
cmd.FailOnError(err, "expiration-mailer has failed")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -199,7 +199,7 @@ func main() {
|
|||
tailers = append(tailers, t)
|
||||
}
|
||||
|
||||
cmd.CatchSignals(logger, func() {
|
||||
defer func() {
|
||||
for _, t := range tailers {
|
||||
// The tail module seems to have a race condition that will generate
|
||||
// errors like this on shutdown:
|
||||
|
@ -211,7 +211,9 @@ func main() {
|
|||
_ = t.Stop()
|
||||
t.Cleanup()
|
||||
}
|
||||
})
|
||||
}()
|
||||
|
||||
cmd.WaitForSignal()
|
||||
}
|
||||
|
||||
func init() {
|
||||
|
|
|
@ -227,25 +227,24 @@ as generated by Boulder's ceremony command.
|
|||
Handler: m,
|
||||
}
|
||||
|
||||
done := make(chan bool)
|
||||
go cmd.CatchSignals(logger, func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(),
|
||||
c.OCSPResponder.ShutdownStopTimeout.Duration)
|
||||
defer cancel()
|
||||
_ = srv.Shutdown(ctx)
|
||||
done <- true
|
||||
})
|
||||
|
||||
err = srv.ListenAndServe()
|
||||
if err != nil && err != http.ErrServerClosed {
|
||||
cmd.FailOnError(err, "Running HTTP server")
|
||||
}
|
||||
|
||||
// https://godoc.org/net/http#Server.Shutdown:
|
||||
// When Shutdown is called, Serve, ListenAndServe, and ListenAndServeTLS
|
||||
// immediately return ErrServerClosed. Make sure the program doesn't exit and
|
||||
// waits instead for Shutdown to return.
|
||||
<-done
|
||||
// When main is ready to exit (because it has received a shutdown signal),
|
||||
// gracefully shutdown the servers. Calling these shutdown functions causes
|
||||
// ListenAndServe() to immediately return, cleaning up the server goroutines
|
||||
// as well, then waits for any lingering connection-handing goroutines to
|
||||
// finish and clean themselves up.
|
||||
defer func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(),
|
||||
c.OCSPResponder.ShutdownStopTimeout.Duration)
|
||||
defer cancel()
|
||||
srv.Shutdown(ctx)
|
||||
}()
|
||||
|
||||
cmd.WaitForSignal()
|
||||
}
|
||||
|
||||
// ocspMux partially implements the interface defined for http.ServeMux but doesn't implement
|
||||
|
|
26
cmd/shell.go
26
cmd/shell.go
|
@ -418,18 +418,26 @@ func VersionString() string {
|
|||
return fmt.Sprintf("Versions: %s=(%s %s) Golang=(%s) BuildHost=(%s)", command(), core.GetBuildID(), core.GetBuildTime(), runtime.Version(), core.GetBuildHost())
|
||||
}
|
||||
|
||||
// CatchSignals catches SIGTERM, SIGINT, SIGHUP and executes a callback
|
||||
// method before exiting
|
||||
func CatchSignals(logger blog.Logger, callback func()) {
|
||||
// CatchSignals blocks until a SIGTERM, SIGINT, or SIGHUP is received, then
|
||||
// executes the given callback. The callback should not block, it should simply
|
||||
// signal other goroutines (particularly the main goroutine) to clean themselves
|
||||
// up and exit. This function is intended to be called in its own goroutine,
|
||||
// while the main goroutine waits for an indication that the other goroutines
|
||||
// have exited cleanly.
|
||||
func CatchSignals(callback func()) {
|
||||
WaitForSignal()
|
||||
callback()
|
||||
}
|
||||
|
||||
// WaitForSignal blocks until a SIGTERM, SIGINT, or SIGHUP is received. It then
|
||||
// returns, allowing execution to resume, generally allowing a main() function
|
||||
// to return and trigger and deferred cleanup functions. This function is
|
||||
// intended to be called directly from the main goroutine, while a gRPC or HTTP
|
||||
// server runs in a background goroutine.
|
||||
func WaitForSignal() {
|
||||
sigChan := make(chan os.Signal, 1)
|
||||
signal.Notify(sigChan, syscall.SIGTERM)
|
||||
signal.Notify(sigChan, syscall.SIGINT)
|
||||
signal.Notify(sigChan, syscall.SIGHUP)
|
||||
|
||||
<-sigChan
|
||||
if callback != nil {
|
||||
callback()
|
||||
}
|
||||
|
||||
os.Exit(0)
|
||||
}
|
||||
|
|
|
@ -5,10 +5,7 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"os/signal"
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
grpc_prometheus "github.com/grpc-ecosystem/go-grpc-prometheus"
|
||||
"github.com/jmhodges/clock"
|
||||
|
@ -175,15 +172,10 @@ func (sb *serverBuilder) Build(tlsConfig *tls.Config, statsRegistry prometheus.R
|
|||
// Start a goroutine which listens for a termination signal, and then
|
||||
// gracefully stops the gRPC server. This in turn causes the start() function
|
||||
// to exit, allowing its caller (generally a main() function) to exit.
|
||||
go func() {
|
||||
sigChan := make(chan os.Signal, 1)
|
||||
signal.Notify(sigChan, syscall.SIGTERM)
|
||||
signal.Notify(sigChan, syscall.SIGINT)
|
||||
signal.Notify(sigChan, syscall.SIGHUP)
|
||||
<-sigChan
|
||||
go cmd.CatchSignals(func() {
|
||||
healthSrv.Shutdown()
|
||||
server.GracefulStop()
|
||||
}()
|
||||
})
|
||||
|
||||
return start, nil
|
||||
}
|
||||
|
|
|
@ -1,13 +1,14 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/letsencrypt/boulder/akamai"
|
||||
"github.com/letsencrypt/boulder/cmd"
|
||||
|
@ -92,9 +93,23 @@ func main() {
|
|||
w.Write(resp)
|
||||
})
|
||||
|
||||
// The gosec linter complains that timeouts cannot be set here. That's fine,
|
||||
// because this is test-only code.
|
||||
////nolint:gosec
|
||||
go log.Fatal(http.ListenAndServe(*listenAddr, nil))
|
||||
cmd.CatchSignals(nil, nil)
|
||||
s := http.Server{
|
||||
ReadTimeout: 30 * time.Second,
|
||||
Addr: *listenAddr,
|
||||
}
|
||||
|
||||
go func() {
|
||||
err := s.ListenAndServe()
|
||||
if err != nil && err != http.ErrServerClosed {
|
||||
cmd.FailOnError(err, "Running TLS server")
|
||||
}
|
||||
}()
|
||||
|
||||
defer func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
s.Shutdown(ctx)
|
||||
}()
|
||||
|
||||
cmd.WaitForSignal()
|
||||
}
|
||||
|
|
|
@ -257,5 +257,5 @@ func main() {
|
|||
for _, p := range c.Personalities {
|
||||
go runPersonality(p)
|
||||
}
|
||||
cmd.CatchSignals(nil, nil)
|
||||
cmd.WaitForSignal()
|
||||
}
|
||||
|
|
|
@ -31,7 +31,7 @@ import (
|
|||
var (
|
||||
// stringToOperation maps a configured plan action to a function that can
|
||||
// operate on a state/context.
|
||||
stringToOperation = map[string]func(*State, *context) error{
|
||||
stringToOperation = map[string]func(*State, *acmeCache) error{
|
||||
"newAccount": newAccount,
|
||||
"getAccount": getAccount,
|
||||
"newOrder": newOrder,
|
||||
|
@ -60,8 +60,8 @@ type OrderJSON struct {
|
|||
}
|
||||
|
||||
// getAccount takes a randomly selected v2 account from `state.accts` and puts it
|
||||
// into `ctx.acct`. The context `nonceSource` is also populated as convenience.
|
||||
func getAccount(s *State, ctx *context) error {
|
||||
// into `c.acct`. The context `nonceSource` is also populated as convenience.
|
||||
func getAccount(s *State, c *acmeCache) error {
|
||||
s.rMu.RLock()
|
||||
defer s.rMu.RUnlock()
|
||||
|
||||
|
@ -71,8 +71,8 @@ func getAccount(s *State, ctx *context) error {
|
|||
}
|
||||
|
||||
// Select a random account from the state and put it into the context
|
||||
ctx.acct = s.accts[mrand.Intn(len(s.accts))]
|
||||
ctx.ns = &nonceSource{s: s}
|
||||
c.acct = s.accts[mrand.Intn(len(s.accts))]
|
||||
c.ns = &nonceSource{s: s}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -81,11 +81,11 @@ func getAccount(s *State, ctx *context) error {
|
|||
// then `newAccount` puts an existing account from the state into the context,
|
||||
// otherwise it creates a new account and puts it into both the state and the
|
||||
// context.
|
||||
func newAccount(s *State, ctx *context) error {
|
||||
func newAccount(s *State, c *acmeCache) error {
|
||||
// Check the max regs and if exceeded, just return an existing account instead
|
||||
// of creating a new one.
|
||||
if s.maxRegs != 0 && s.numAccts() >= s.maxRegs {
|
||||
return getAccount(s, ctx)
|
||||
return getAccount(s, c)
|
||||
}
|
||||
|
||||
// Create a random signing key
|
||||
|
@ -93,10 +93,10 @@ func newAccount(s *State, ctx *context) error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ctx.acct = &account{
|
||||
c.acct = &account{
|
||||
key: signKey,
|
||||
}
|
||||
ctx.ns = &nonceSource{s: s}
|
||||
c.ns = &nonceSource{s: s}
|
||||
|
||||
// Prepare an account registration message body
|
||||
reqBody := struct {
|
||||
|
@ -117,7 +117,7 @@ func newAccount(s *State, ctx *context) error {
|
|||
// Sign the new account registration body using a JWS with an embedded JWK
|
||||
// because we do not have a key ID from the server yet.
|
||||
newAccountURL := s.directory.EndpointURL(acme.NewAccountEndpoint)
|
||||
jws, err := ctx.signEmbeddedV2Request(reqBodyStr, newAccountURL)
|
||||
jws, err := c.signEmbeddedV2Request(reqBodyStr, newAccountURL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -126,7 +126,7 @@ func newAccount(s *State, ctx *context) error {
|
|||
resp, err := s.post(
|
||||
newAccountURL,
|
||||
bodyBuf,
|
||||
ctx.ns,
|
||||
c.ns,
|
||||
string(acme.NewAccountEndpoint),
|
||||
http.StatusCreated)
|
||||
if err != nil {
|
||||
|
@ -140,10 +140,10 @@ func newAccount(s *State, ctx *context) error {
|
|||
if locHeader == "" {
|
||||
return fmt.Errorf("%s, bad response - no Location header with account ID", newAccountURL)
|
||||
}
|
||||
ctx.acct.id = locHeader
|
||||
c.acct.id = locHeader
|
||||
|
||||
// Add the account to the state
|
||||
s.addAccount(ctx.acct)
|
||||
s.addAccount(c.acct)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -160,7 +160,7 @@ func randDomain(base string) string {
|
|||
|
||||
// newOrder creates a new pending order object for a random set of domains using
|
||||
// the context's account.
|
||||
func newOrder(s *State, ctx *context) error {
|
||||
func newOrder(s *State, c *acmeCache) error {
|
||||
// Pick a random number of names within the constraints of the maxNamesPerCert
|
||||
// parameter
|
||||
orderSize := 1 + mrand.Intn(s.maxNamesPerCert-1)
|
||||
|
@ -187,7 +187,7 @@ func newOrder(s *State, ctx *context) error {
|
|||
|
||||
// Sign the new order request with the context account's key/key ID
|
||||
newOrderURL := s.directory.EndpointURL(acme.NewOrderEndpoint)
|
||||
jws, err := ctx.signKeyIDV2Request(initOrderStr, newOrderURL)
|
||||
jws, err := c.signKeyIDV2Request(initOrderStr, newOrderURL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -196,7 +196,7 @@ func newOrder(s *State, ctx *context) error {
|
|||
resp, err := s.post(
|
||||
newOrderURL,
|
||||
bodyBuf,
|
||||
ctx.ns,
|
||||
c.ns,
|
||||
string(acme.NewOrderEndpoint),
|
||||
http.StatusCreated)
|
||||
if err != nil {
|
||||
|
@ -223,24 +223,24 @@ func newOrder(s *State, ctx *context) error {
|
|||
orderJSON.URL = orderURL
|
||||
|
||||
// Store the pending order in the context
|
||||
ctx.pendingOrders = append(ctx.pendingOrders, &orderJSON)
|
||||
c.pendingOrders = append(c.pendingOrders, &orderJSON)
|
||||
return nil
|
||||
}
|
||||
|
||||
// popPendingOrder *removes* a random pendingOrder from the context, returning
|
||||
// it.
|
||||
func popPendingOrder(ctx *context) *OrderJSON {
|
||||
orderIndex := mrand.Intn(len(ctx.pendingOrders))
|
||||
order := ctx.pendingOrders[orderIndex]
|
||||
ctx.pendingOrders = append(ctx.pendingOrders[:orderIndex], ctx.pendingOrders[orderIndex+1:]...)
|
||||
func popPendingOrder(c *acmeCache) *OrderJSON {
|
||||
orderIndex := mrand.Intn(len(c.pendingOrders))
|
||||
order := c.pendingOrders[orderIndex]
|
||||
c.pendingOrders = append(c.pendingOrders[:orderIndex], c.pendingOrders[orderIndex+1:]...)
|
||||
return order
|
||||
}
|
||||
|
||||
// getAuthorization fetches an authorization by GET-ing the provided URL. It
|
||||
// records the latency and result of the GET operation in the state.
|
||||
func getAuthorization(s *State, ctx *context, url string) (*core.Authorization, error) {
|
||||
func getAuthorization(s *State, c *acmeCache, url string) (*core.Authorization, error) {
|
||||
latencyTag := "/acme/authz/{ID}"
|
||||
resp, err := postAsGet(s, ctx, url, latencyTag)
|
||||
resp, err := postAsGet(s, c, url, latencyTag)
|
||||
// If there was an error, note the state and return
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%s bad response: %s", url, err)
|
||||
|
@ -269,7 +269,7 @@ func getAuthorization(s *State, ctx *context, url string) (*core.Authorization,
|
|||
// HTTP-01 challenge using the context's account and the state's challenge
|
||||
// server. Aftering POSTing the authorization's HTTP-01 challenge the
|
||||
// authorization will be polled waiting for a state change.
|
||||
func completeAuthorization(authz *core.Authorization, s *State, ctx *context) error {
|
||||
func completeAuthorization(authz *core.Authorization, s *State, c *acmeCache) error {
|
||||
// Skip if the authz isn't pending
|
||||
if authz.Status != core.StatusPending {
|
||||
return nil
|
||||
|
@ -283,7 +283,7 @@ func completeAuthorization(authz *core.Authorization, s *State, ctx *context) er
|
|||
}
|
||||
|
||||
// Compute the key authorization from the context account's key
|
||||
jwk := &jose.JSONWebKey{Key: &ctx.acct.key.PublicKey}
|
||||
jwk := &jose.JSONWebKey{Key: &c.acct.key.PublicKey}
|
||||
thumbprint, err := jwk.Thumbprint(crypto.SHA256)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -311,7 +311,7 @@ func completeAuthorization(authz *core.Authorization, s *State, ctx *context) er
|
|||
}
|
||||
|
||||
// Prepare the Challenge POST body
|
||||
jws, err := ctx.signKeyIDV2Request([]byte(`{}`), chalToSolve.URL)
|
||||
jws, err := c.signKeyIDV2Request([]byte(`{}`), chalToSolve.URL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -320,7 +320,7 @@ func completeAuthorization(authz *core.Authorization, s *State, ctx *context) er
|
|||
resp, err := s.post(
|
||||
chalToSolve.URL,
|
||||
requestPayload,
|
||||
ctx.ns,
|
||||
c.ns,
|
||||
"/acme/challenge/{ID}", // We want all challenge POST latencies to be grouped
|
||||
http.StatusOK,
|
||||
)
|
||||
|
@ -337,7 +337,7 @@ func completeAuthorization(authz *core.Authorization, s *State, ctx *context) er
|
|||
|
||||
// Poll the authorization waiting for the challenge response to be recorded in
|
||||
// a change of state. The polling may sleep and retry a few times if required
|
||||
err = pollAuthorization(authz, s, ctx)
|
||||
err = pollAuthorization(authz, s, c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -351,11 +351,11 @@ func completeAuthorization(authz *core.Authorization, s *State, ctx *context) er
|
|||
// be valid. If the status is invalid, or if three GETs do not produce the
|
||||
// correct authorization state an error is returned. If no error is returned
|
||||
// then the authorization is valid and ready.
|
||||
func pollAuthorization(authz *core.Authorization, s *State, ctx *context) error {
|
||||
func pollAuthorization(authz *core.Authorization, s *State, c *acmeCache) error {
|
||||
authzURL := authz.ID
|
||||
for i := 0; i < 3; i++ {
|
||||
// Fetch the authz by its URL
|
||||
authz, err := getAuthorization(s, ctx, authzURL)
|
||||
authz, err := getAuthorization(s, c, authzURL)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
@ -377,25 +377,25 @@ func pollAuthorization(authz *core.Authorization, s *State, ctx *context) error
|
|||
// authorization's HTTP-01 challenge using the context's account, and finally
|
||||
// placing the now-ready-to-be-finalized order into the context's list of
|
||||
// fulfilled orders.
|
||||
func fulfillOrder(s *State, ctx *context) error {
|
||||
func fulfillOrder(s *State, c *acmeCache) error {
|
||||
// There must be at least one pending order in the context to fulfill
|
||||
if len(ctx.pendingOrders) == 0 {
|
||||
if len(c.pendingOrders) == 0 {
|
||||
return errors.New("no pending orders to fulfill")
|
||||
}
|
||||
|
||||
// Get an order to fulfill from the context
|
||||
order := popPendingOrder(ctx)
|
||||
order := popPendingOrder(c)
|
||||
|
||||
// Each of its authorizations need to be processed
|
||||
for _, url := range order.Authorizations {
|
||||
// Fetch the authz by its URL
|
||||
authz, err := getAuthorization(s, ctx, url)
|
||||
authz, err := getAuthorization(s, c, url)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Complete the authorization by solving a challenge
|
||||
err = completeAuthorization(authz, s, ctx)
|
||||
err = completeAuthorization(authz, s, c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -403,16 +403,16 @@ func fulfillOrder(s *State, ctx *context) error {
|
|||
|
||||
// Once all of the authorizations have been fulfilled the order is fulfilled
|
||||
// and ready for future finalization.
|
||||
ctx.fulfilledOrders = append(ctx.fulfilledOrders, order.URL)
|
||||
c.fulfilledOrders = append(c.fulfilledOrders, order.URL)
|
||||
return nil
|
||||
}
|
||||
|
||||
// getOrder GETs an order by URL, returning an OrderJSON object. It tracks the
|
||||
// latency of the GET operation in the provided state.
|
||||
func getOrder(s *State, ctx *context, url string) (*OrderJSON, error) {
|
||||
func getOrder(s *State, c *acmeCache, url string) (*OrderJSON, error) {
|
||||
latencyTag := "/acme/order/{ID}"
|
||||
// POST-as-GET the order URL
|
||||
resp, err := postAsGet(s, ctx, url, latencyTag)
|
||||
resp, err := postAsGet(s, c, url, latencyTag)
|
||||
// If there was an error, track that result
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%s bad response: %s", url, err)
|
||||
|
@ -440,10 +440,10 @@ func getOrder(s *State, ctx *context, url string) (*OrderJSON, error) {
|
|||
// valid such that a certificate URL for the order is known. Three attempts are
|
||||
// made to check the order status, sleeping 3s between each. If these attempts
|
||||
// expire without the status becoming valid an error is returned.
|
||||
func pollOrderForCert(order *OrderJSON, s *State, ctx *context) (*OrderJSON, error) {
|
||||
func pollOrderForCert(order *OrderJSON, s *State, c *acmeCache) (*OrderJSON, error) {
|
||||
for i := 0; i < 3; i++ {
|
||||
// Fetch the order by its URL
|
||||
order, err := getOrder(s, ctx, order.URL)
|
||||
order, err := getOrder(s, c, order.URL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -463,10 +463,10 @@ func pollOrderForCert(order *OrderJSON, s *State, ctx *context) (*OrderJSON, err
|
|||
|
||||
// popFulfilledOrder **removes** a fulfilled order from the context, returning
|
||||
// it. Fulfilled orders have all of their authorizations satisfied.
|
||||
func popFulfilledOrder(ctx *context) string {
|
||||
orderIndex := mrand.Intn(len(ctx.fulfilledOrders))
|
||||
order := ctx.fulfilledOrders[orderIndex]
|
||||
ctx.fulfilledOrders = append(ctx.fulfilledOrders[:orderIndex], ctx.fulfilledOrders[orderIndex+1:]...)
|
||||
func popFulfilledOrder(c *acmeCache) string {
|
||||
orderIndex := mrand.Intn(len(c.fulfilledOrders))
|
||||
order := c.fulfilledOrders[orderIndex]
|
||||
c.fulfilledOrders = append(c.fulfilledOrders[:orderIndex], c.fulfilledOrders[orderIndex+1:]...)
|
||||
return order
|
||||
}
|
||||
|
||||
|
@ -475,15 +475,15 @@ func popFulfilledOrder(ctx *context) string {
|
|||
// `certKey`. The order is then polled for the status to change to valid so that
|
||||
// the certificate URL can be added to the context. The context's `certs` list
|
||||
// is updated with the URL for the order's certificate.
|
||||
func finalizeOrder(s *State, ctx *context) error {
|
||||
func finalizeOrder(s *State, c *acmeCache) error {
|
||||
// There must be at least one fulfilled order in the context
|
||||
if len(ctx.fulfilledOrders) < 1 {
|
||||
if len(c.fulfilledOrders) < 1 {
|
||||
return errors.New("No fulfilled orders in the context ready to be finalized")
|
||||
}
|
||||
|
||||
// Pop a fulfilled order to process, and then GET its contents
|
||||
orderID := popFulfilledOrder(ctx)
|
||||
order, err := getOrder(s, ctx, orderID)
|
||||
orderID := popFulfilledOrder(c)
|
||||
order, err := getOrder(s, c, orderID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -519,7 +519,7 @@ func finalizeOrder(s *State, ctx *context) error {
|
|||
)
|
||||
|
||||
// Sign the request body with the context's account key/keyID
|
||||
jws, err := ctx.signKeyIDV2Request([]byte(request), finalizeURL)
|
||||
jws, err := c.signKeyIDV2Request([]byte(request), finalizeURL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -528,7 +528,7 @@ func finalizeOrder(s *State, ctx *context) error {
|
|||
resp, err := s.post(
|
||||
finalizeURL,
|
||||
requestPayload,
|
||||
ctx.ns,
|
||||
c.ns,
|
||||
"/acme/order/finalize", // We want all order finalizations to be grouped.
|
||||
http.StatusOK,
|
||||
)
|
||||
|
@ -544,7 +544,7 @@ func finalizeOrder(s *State, ctx *context) error {
|
|||
}
|
||||
|
||||
// Poll the order waiting for the certificate to be ready
|
||||
completedOrder, err := pollOrderForCert(order, s, ctx)
|
||||
completedOrder, err := pollOrderForCert(order, s, c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -556,8 +556,8 @@ func finalizeOrder(s *State, ctx *context) error {
|
|||
}
|
||||
|
||||
// Append the certificate URL into the context's list of certificates
|
||||
ctx.certs = append(ctx.certs, certURL)
|
||||
ctx.finalizedOrders = append(ctx.finalizedOrders, order.URL)
|
||||
c.certs = append(c.certs, certURL)
|
||||
c.finalizedOrders = append(c.finalizedOrders, order.URL)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -567,27 +567,27 @@ func finalizeOrder(s *State, ctx *context) error {
|
|||
// responsible for closing the HTTP response body.
|
||||
//
|
||||
// See RFC 8555 Section 6.3 for more information on POST-as-GET requests.
|
||||
func postAsGet(s *State, ctx *context, url string, latencyTag string) (*http.Response, error) {
|
||||
func postAsGet(s *State, c *acmeCache, url string, latencyTag string) (*http.Response, error) {
|
||||
// Create the POST-as-GET request JWS
|
||||
jws, err := ctx.signKeyIDV2Request([]byte(""), url)
|
||||
jws, err := c.signKeyIDV2Request([]byte(""), url)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
requestPayload := []byte(jws.FullSerialize())
|
||||
|
||||
return s.post(url, requestPayload, ctx.ns, latencyTag, http.StatusOK)
|
||||
return s.post(url, requestPayload, c.ns, latencyTag, http.StatusOK)
|
||||
}
|
||||
|
||||
func popCertificate(ctx *context) string {
|
||||
certIndex := mrand.Intn(len(ctx.certs))
|
||||
certURL := ctx.certs[certIndex]
|
||||
ctx.certs = append(ctx.certs[:certIndex], ctx.certs[certIndex+1:]...)
|
||||
func popCertificate(c *acmeCache) string {
|
||||
certIndex := mrand.Intn(len(c.certs))
|
||||
certURL := c.certs[certIndex]
|
||||
c.certs = append(c.certs[:certIndex], c.certs[certIndex+1:]...)
|
||||
return certURL
|
||||
}
|
||||
|
||||
func getCert(s *State, ctx *context, url string) ([]byte, error) {
|
||||
func getCert(s *State, c *acmeCache, url string) ([]byte, error) {
|
||||
latencyTag := "/acme/cert/{serial}"
|
||||
resp, err := postAsGet(s, ctx, url, latencyTag)
|
||||
resp, err := postAsGet(s, c, url, latencyTag)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%s bad response: %s", url, err)
|
||||
}
|
||||
|
@ -599,8 +599,8 @@ func getCert(s *State, ctx *context, url string) ([]byte, error) {
|
|||
// and sends a revocation request for the certificate to the ACME server.
|
||||
// The revocation request is signed with the account key rather than the certificate
|
||||
// key.
|
||||
func revokeCertificate(s *State, ctx *context) error {
|
||||
if len(ctx.certs) < 1 {
|
||||
func revokeCertificate(s *State, c *acmeCache) error {
|
||||
if len(c.certs) < 1 {
|
||||
return errors.New("No certificates in the context that can be revoked")
|
||||
}
|
||||
|
||||
|
@ -608,8 +608,8 @@ func revokeCertificate(s *State, ctx *context) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
certURL := popCertificate(ctx)
|
||||
certPEM, err := getCert(s, ctx, certURL)
|
||||
certURL := popCertificate(c)
|
||||
certPEM, err := getCert(s, c, certURL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -630,7 +630,7 @@ func revokeCertificate(s *State, ctx *context) error {
|
|||
revokeURL := s.directory.EndpointURL(acme.RevokeCertEndpoint)
|
||||
// TODO(roland): randomly use the certificate key to sign the request instead of
|
||||
// the account key
|
||||
jws, err := ctx.signKeyIDV2Request(revokeJSON, revokeURL)
|
||||
jws, err := c.signKeyIDV2Request(revokeJSON, revokeURL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -639,7 +639,7 @@ func revokeCertificate(s *State, ctx *context) error {
|
|||
resp, err := s.post(
|
||||
revokeURL,
|
||||
requestPayload,
|
||||
ctx.ns,
|
||||
c.ns,
|
||||
"/acme/revoke-cert",
|
||||
http.StatusOK,
|
||||
)
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"flag"
|
||||
"fmt"
|
||||
|
@ -118,9 +119,11 @@ func main() {
|
|||
"HTTPOneAddrs, TLSALPNOneAddrs or DNSAddrs\n")
|
||||
}
|
||||
|
||||
go cmd.CatchSignals(nil, nil)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
go cmd.CatchSignals(cancel)
|
||||
|
||||
err = s.Run(
|
||||
ctx,
|
||||
config.HTTPOneAddrs,
|
||||
config.TLSALPNOneAddrs,
|
||||
config.DNSAddrs,
|
||||
|
|
|
@ -2,6 +2,7 @@ package main
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
|
@ -15,14 +16,12 @@ import (
|
|||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"reflect"
|
||||
"runtime"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"gopkg.in/go-jose/go-jose.v2"
|
||||
|
@ -54,7 +53,7 @@ func (acct *account) update(finalizedOrders, certs []string) {
|
|||
acct.certs = append(acct.certs, certs...)
|
||||
}
|
||||
|
||||
type context struct {
|
||||
type acmeCache struct {
|
||||
// The current V2 account (may be nil for legacy load generation)
|
||||
acct *account
|
||||
// Pending orders waiting for authorization challenge validation
|
||||
|
@ -70,12 +69,12 @@ type context struct {
|
|||
ns *nonceSource
|
||||
}
|
||||
|
||||
// signEmbeddedV2Request signs the provided request data using the context's
|
||||
// signEmbeddedV2Request signs the provided request data using the acmeCache's
|
||||
// account's private key. The provided URL is set as a protected header per ACME
|
||||
// v2 JWS standards. The resulting JWS contains an **embedded** JWK - this makes
|
||||
// this function primarily applicable to new account requests where no key ID is
|
||||
// known.
|
||||
func (c *context) signEmbeddedV2Request(data []byte, url string) (*jose.JSONWebSignature, error) {
|
||||
func (c *acmeCache) signEmbeddedV2Request(data []byte, url string) (*jose.JSONWebSignature, error) {
|
||||
// Create a signing key for the account's private key
|
||||
signingKey := jose.SigningKey{
|
||||
Key: c.acct.key,
|
||||
|
@ -101,13 +100,13 @@ func (c *context) signEmbeddedV2Request(data []byte, url string) (*jose.JSONWebS
|
|||
return signed, nil
|
||||
}
|
||||
|
||||
// signKeyIDV2Request signs the provided request data using the context's
|
||||
// signKeyIDV2Request signs the provided request data using the acmeCache's
|
||||
// account's private key. The provided URL is set as a protected header per ACME
|
||||
// v2 JWS standards. The resulting JWS contains a Key ID header that is
|
||||
// populated using the context's account's ID. This is the default JWS signing
|
||||
// populated using the acmeCache's account's ID. This is the default JWS signing
|
||||
// style for ACME v2 requests and should be used everywhere but where the key ID
|
||||
// is unknown (e.g. new-account requests where an account doesn't exist yet).
|
||||
func (c *context) signKeyIDV2Request(data []byte, url string) (*jose.JSONWebSignature, error) {
|
||||
func (c *acmeCache) signKeyIDV2Request(data []byte, url string) (*jose.JSONWebSignature, error) {
|
||||
// Create a JWK with the account's private key and key ID
|
||||
jwk := &jose.JSONWebKey{
|
||||
Key: c.acct.key,
|
||||
|
@ -168,7 +167,7 @@ type State struct {
|
|||
realIP string
|
||||
certKey *ecdsa.PrivateKey
|
||||
|
||||
operations []func(*State, *context) error
|
||||
operations []func(*State, *acmeCache) error
|
||||
|
||||
rMu sync.RWMutex
|
||||
|
||||
|
@ -349,6 +348,7 @@ func New(
|
|||
|
||||
// Run runs the WFE load-generator
|
||||
func (s *State) Run(
|
||||
ctx context.Context,
|
||||
httpOneAddrs []string,
|
||||
tlsALPNOneAddrs []string,
|
||||
dnsAddrs []string,
|
||||
|
@ -387,8 +387,6 @@ func (s *State) Run(
|
|||
|
||||
// Run sending loop
|
||||
stop := make(chan bool, 1)
|
||||
sigs := make(chan os.Signal, 1)
|
||||
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
|
||||
fmt.Println("[+] Beginning execution plan")
|
||||
i := int64(0)
|
||||
go func() {
|
||||
|
@ -429,8 +427,8 @@ func (s *State) Run(
|
|||
select {
|
||||
case <-time.After(p.Runtime):
|
||||
fmt.Println("[+] Execution plan finished")
|
||||
case sig := <-sigs:
|
||||
fmt.Printf("[!] Execution plan interrupted: %s caught\n", sig.String())
|
||||
case <-ctx.Done():
|
||||
fmt.Println("[!] Execution plan cancelled")
|
||||
}
|
||||
stop <- true
|
||||
fmt.Println("[+] Waiting for pending flows to finish before killing challenge server")
|
||||
|
@ -586,19 +584,19 @@ func (s *State) addAccount(acct *account) {
|
|||
|
||||
func (s *State) sendCall() {
|
||||
defer s.wg.Done()
|
||||
ctx := &context{}
|
||||
c := &acmeCache{}
|
||||
|
||||
for _, op := range s.operations {
|
||||
err := op(s, ctx)
|
||||
err := op(s, c)
|
||||
if err != nil {
|
||||
method := runtime.FuncForPC(reflect.ValueOf(op).Pointer()).Name()
|
||||
fmt.Printf("[FAILED] %s: %s\n", method, err)
|
||||
break
|
||||
}
|
||||
}
|
||||
// If the context's V2 account isn't nil, update it based on the context's
|
||||
// If the acmeCache's V2 account isn't nil, update it based on the cache's
|
||||
// finalizedOrders and certs.
|
||||
if ctx.acct != nil {
|
||||
ctx.acct.update(ctx.finalizedOrders, ctx.certs)
|
||||
if c.acct != nil {
|
||||
c.acct.update(c.finalizedOrders, c.certs)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -3,6 +3,7 @@ package main
|
|||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"flag"
|
||||
"fmt"
|
||||
|
@ -186,11 +187,19 @@ scan:
|
|||
}
|
||||
}
|
||||
|
||||
func (srv *mailSrv) serveSMTP(l net.Listener) error {
|
||||
func (srv *mailSrv) serveSMTP(ctx context.Context, l net.Listener) error {
|
||||
for {
|
||||
conn, err := l.Accept()
|
||||
if err != nil {
|
||||
return err
|
||||
// If the accept call returned an error because the listener has been
|
||||
// closed, then the context should have been canceled too. In that case,
|
||||
// ignore the error.
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
default:
|
||||
return err
|
||||
}
|
||||
}
|
||||
go srv.handleConn(conn)
|
||||
}
|
||||
|
@ -233,10 +242,10 @@ func main() {
|
|||
}
|
||||
}()
|
||||
|
||||
go cmd.CatchSignals(nil, nil)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
err = srv.serveSMTP(l)
|
||||
if err != nil {
|
||||
log.Fatalln(err, "Failed to accept connection")
|
||||
}
|
||||
go cmd.FailOnError(srv.serveSMTP(ctx, l), "Failed to accept connection")
|
||||
|
||||
cmd.WaitForSignal()
|
||||
}
|
||||
|
|
|
@ -1,12 +1,13 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/letsencrypt/boulder/cmd"
|
||||
"github.com/letsencrypt/boulder/core"
|
||||
|
@ -93,9 +94,23 @@ func main() {
|
|||
http.HandleFunc("/clear", srv.handleClear)
|
||||
http.HandleFunc("/query", srv.handleQuery)
|
||||
|
||||
// The gosec linter complains that timeouts cannot be set here. That's fine,
|
||||
// because this is test-only code.
|
||||
////nolint:gosec
|
||||
go log.Fatal(http.ListenAndServe(*listenAddr, nil))
|
||||
cmd.CatchSignals(nil, nil)
|
||||
s := http.Server{
|
||||
ReadTimeout: 30 * time.Second,
|
||||
Addr: *listenAddr,
|
||||
}
|
||||
|
||||
go func() {
|
||||
err := s.ListenAndServe()
|
||||
if err != nil && err != http.ErrServerClosed {
|
||||
cmd.FailOnError(err, "Running TLS server")
|
||||
}
|
||||
}()
|
||||
|
||||
defer func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
s.Shutdown(ctx)
|
||||
}()
|
||||
|
||||
cmd.WaitForSignal()
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue