Fix non-gRPC process cleanup and exit (#6808)
Although #6771 significantly cleaned up how gRPC services stop and clean up, it didn't make any changes to our HTTP servers or our non-server (e.g. crl-updater, log-validator) processes. This change finishes the work. Add a new helper method cmd.WaitForSignal, which simply blocks until one of the three signals we care about is received. This easily replaces all calls to cmd.CatchSignals which passed `nil` as the callback argument, with the added advantage that it doesn't call os.Exit() and therefore allows deferred cleanup functions to execute. This new function is intended to be the last line of main(), allowing the whole process to exit once it returns. Reimplement cmd.CatchSignals as a thin wrapper around cmd.WaitForSignal, but with the added callback functionality. Also remove the os.Exit() call from CatchSignals, so that the main goroutine is allowed to finish whatever it's doing, call deferred functions, and exit naturally. Update all of our non-gRPC binaries to use one of these two functions. The vast majority use WaitForSignal, as they run their main processing loop in a background goroutine. A few (particularly those that can run either in run-once or in daemonized mode) still use CatchSignals, since their primary processing happens directly on the main goroutine. The changes to //test/load-generator are the most invasive, simply because that binary needed to have a context plumbed into it for proper cancellation, but it already had a custom struct type named "context" which needed to be renamed to avoid shadowing. Fixes https://github.com/letsencrypt/boulder/issues/6794
This commit is contained in:
parent
98fa0f07b4
commit
bd1d27b8e8
|
@ -535,9 +535,6 @@ func main() {
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// The gosec linter complains that ReadHeaderTimeout is not set. That's fine,
|
|
||||||
// because that field inherits its value from ReadTimeout.
|
|
||||||
////nolint:gosec
|
|
||||||
tlsSrv := http.Server{
|
tlsSrv := http.Server{
|
||||||
ReadTimeout: 30 * time.Second,
|
ReadTimeout: 30 * time.Second,
|
||||||
WriteTimeout: 120 * time.Second,
|
WriteTimeout: 120 * time.Second,
|
||||||
|
@ -555,20 +552,18 @@ func main() {
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
done := make(chan bool)
|
// When main is ready to exit (because it has received a shutdown signal),
|
||||||
go cmd.CatchSignals(logger, func() {
|
// gracefully shutdown the servers. Calling these shutdown functions causes
|
||||||
|
// ListenAndServe() and ListenAndServeTLS() to immediately return, then waits
|
||||||
|
// for any lingering connection-handling goroutines to finish their work.
|
||||||
|
defer func() {
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), c.WFE.ShutdownStopTimeout.Duration)
|
ctx, cancel := context.WithTimeout(context.Background(), c.WFE.ShutdownStopTimeout.Duration)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
_ = srv.Shutdown(ctx)
|
_ = srv.Shutdown(ctx)
|
||||||
_ = tlsSrv.Shutdown(ctx)
|
_ = tlsSrv.Shutdown(ctx)
|
||||||
done <- true
|
}()
|
||||||
})
|
|
||||||
|
|
||||||
// https://godoc.org/net/http#Server.Shutdown:
|
cmd.WaitForSignal()
|
||||||
// When Shutdown is called, Serve, ListenAndServe, and ListenAndServeTLS
|
|
||||||
// immediately return ErrServerClosed. Make sure the program doesn't exit and
|
|
||||||
// waits instead for Shutdown to return.
|
|
||||||
<-done
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
|
|
|
@ -2,6 +2,7 @@ package notmain
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"flag"
|
"flag"
|
||||||
"os"
|
"os"
|
||||||
"time"
|
"time"
|
||||||
|
@ -165,15 +166,17 @@ func main() {
|
||||||
cmd.FailOnError(err, "Failed to create crl-updater")
|
cmd.FailOnError(err, "Failed to create crl-updater")
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
go cmd.CatchSignals(logger, cancel)
|
go cmd.CatchSignals(cancel)
|
||||||
|
|
||||||
if *runOnce {
|
if *runOnce {
|
||||||
err = u.Tick(ctx, clk.Now())
|
err = u.Tick(ctx, clk.Now())
|
||||||
|
if err != nil && !errors.Is(err, context.Canceled) {
|
||||||
cmd.FailOnError(err, "")
|
cmd.FailOnError(err, "")
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
err = u.Run(ctx)
|
err = u.Run(ctx)
|
||||||
if err != nil {
|
if err != nil && !errors.Is(err, context.Canceled) {
|
||||||
logger.Err(err.Error())
|
cmd.FailOnError(err, "")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -803,6 +803,11 @@ func main() {
|
||||||
defer logger.AuditPanic()
|
defer logger.AuditPanic()
|
||||||
logger.Info(cmd.VersionString())
|
logger.Info(cmd.VersionString())
|
||||||
|
|
||||||
|
if *daemon && c.Mailer.Frequency.Duration == 0 {
|
||||||
|
fmt.Fprintln(os.Stderr, "mailer.frequency is not set in the JSON config")
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
if *certLimit > 0 {
|
if *certLimit > 0 {
|
||||||
c.Mailer.CertLimit = *certLimit
|
c.Mailer.CertLimit = *certLimit
|
||||||
}
|
}
|
||||||
|
@ -913,32 +918,27 @@ func main() {
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
defer cancel()
|
go cmd.CatchSignals(cancel)
|
||||||
|
|
||||||
go cmd.CatchSignals(logger, func() {
|
|
||||||
cancel()
|
|
||||||
select {} // wait for the `findExpiringCertificates` calls below to exit
|
|
||||||
})
|
|
||||||
|
|
||||||
if *daemon {
|
if *daemon {
|
||||||
if c.Mailer.Frequency.Duration == 0 {
|
|
||||||
fmt.Fprintln(os.Stderr, "mailer.Frequency is not set in the JSON config")
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
t := time.NewTicker(c.Mailer.Frequency.Duration)
|
t := time.NewTicker(c.Mailer.Frequency.Duration)
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-t.C:
|
case <-t.C:
|
||||||
err = m.findExpiringCertificates(ctx)
|
err = m.findExpiringCertificates(ctx)
|
||||||
|
if err != nil && !errors.Is(err, context.Canceled) {
|
||||||
cmd.FailOnError(err, "expiration-mailer has failed")
|
cmd.FailOnError(err, "expiration-mailer has failed")
|
||||||
|
}
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
os.Exit(0)
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
err = m.findExpiringCertificates(ctx)
|
err = m.findExpiringCertificates(ctx)
|
||||||
|
if err != nil && !errors.Is(err, context.Canceled) {
|
||||||
cmd.FailOnError(err, "expiration-mailer has failed")
|
cmd.FailOnError(err, "expiration-mailer has failed")
|
||||||
}
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
|
|
|
@ -199,7 +199,7 @@ func main() {
|
||||||
tailers = append(tailers, t)
|
tailers = append(tailers, t)
|
||||||
}
|
}
|
||||||
|
|
||||||
cmd.CatchSignals(logger, func() {
|
defer func() {
|
||||||
for _, t := range tailers {
|
for _, t := range tailers {
|
||||||
// The tail module seems to have a race condition that will generate
|
// The tail module seems to have a race condition that will generate
|
||||||
// errors like this on shutdown:
|
// errors like this on shutdown:
|
||||||
|
@ -211,7 +211,9 @@ func main() {
|
||||||
_ = t.Stop()
|
_ = t.Stop()
|
||||||
t.Cleanup()
|
t.Cleanup()
|
||||||
}
|
}
|
||||||
})
|
}()
|
||||||
|
|
||||||
|
cmd.WaitForSignal()
|
||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
|
|
|
@ -227,25 +227,24 @@ as generated by Boulder's ceremony command.
|
||||||
Handler: m,
|
Handler: m,
|
||||||
}
|
}
|
||||||
|
|
||||||
done := make(chan bool)
|
|
||||||
go cmd.CatchSignals(logger, func() {
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(),
|
|
||||||
c.OCSPResponder.ShutdownStopTimeout.Duration)
|
|
||||||
defer cancel()
|
|
||||||
_ = srv.Shutdown(ctx)
|
|
||||||
done <- true
|
|
||||||
})
|
|
||||||
|
|
||||||
err = srv.ListenAndServe()
|
err = srv.ListenAndServe()
|
||||||
if err != nil && err != http.ErrServerClosed {
|
if err != nil && err != http.ErrServerClosed {
|
||||||
cmd.FailOnError(err, "Running HTTP server")
|
cmd.FailOnError(err, "Running HTTP server")
|
||||||
}
|
}
|
||||||
|
|
||||||
// https://godoc.org/net/http#Server.Shutdown:
|
// When main is ready to exit (because it has received a shutdown signal),
|
||||||
// When Shutdown is called, Serve, ListenAndServe, and ListenAndServeTLS
|
// gracefully shutdown the servers. Calling these shutdown functions causes
|
||||||
// immediately return ErrServerClosed. Make sure the program doesn't exit and
|
// ListenAndServe() to immediately return, cleaning up the server goroutines
|
||||||
// waits instead for Shutdown to return.
|
// as well, then waits for any lingering connection-handing goroutines to
|
||||||
<-done
|
// finish and clean themselves up.
|
||||||
|
defer func() {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(),
|
||||||
|
c.OCSPResponder.ShutdownStopTimeout.Duration)
|
||||||
|
defer cancel()
|
||||||
|
srv.Shutdown(ctx)
|
||||||
|
}()
|
||||||
|
|
||||||
|
cmd.WaitForSignal()
|
||||||
}
|
}
|
||||||
|
|
||||||
// ocspMux partially implements the interface defined for http.ServeMux but doesn't implement
|
// ocspMux partially implements the interface defined for http.ServeMux but doesn't implement
|
||||||
|
|
26
cmd/shell.go
26
cmd/shell.go
|
@ -418,18 +418,26 @@ func VersionString() string {
|
||||||
return fmt.Sprintf("Versions: %s=(%s %s) Golang=(%s) BuildHost=(%s)", command(), core.GetBuildID(), core.GetBuildTime(), runtime.Version(), core.GetBuildHost())
|
return fmt.Sprintf("Versions: %s=(%s %s) Golang=(%s) BuildHost=(%s)", command(), core.GetBuildID(), core.GetBuildTime(), runtime.Version(), core.GetBuildHost())
|
||||||
}
|
}
|
||||||
|
|
||||||
// CatchSignals catches SIGTERM, SIGINT, SIGHUP and executes a callback
|
// CatchSignals blocks until a SIGTERM, SIGINT, or SIGHUP is received, then
|
||||||
// method before exiting
|
// executes the given callback. The callback should not block, it should simply
|
||||||
func CatchSignals(logger blog.Logger, callback func()) {
|
// signal other goroutines (particularly the main goroutine) to clean themselves
|
||||||
|
// up and exit. This function is intended to be called in its own goroutine,
|
||||||
|
// while the main goroutine waits for an indication that the other goroutines
|
||||||
|
// have exited cleanly.
|
||||||
|
func CatchSignals(callback func()) {
|
||||||
|
WaitForSignal()
|
||||||
|
callback()
|
||||||
|
}
|
||||||
|
|
||||||
|
// WaitForSignal blocks until a SIGTERM, SIGINT, or SIGHUP is received. It then
|
||||||
|
// returns, allowing execution to resume, generally allowing a main() function
|
||||||
|
// to return and trigger and deferred cleanup functions. This function is
|
||||||
|
// intended to be called directly from the main goroutine, while a gRPC or HTTP
|
||||||
|
// server runs in a background goroutine.
|
||||||
|
func WaitForSignal() {
|
||||||
sigChan := make(chan os.Signal, 1)
|
sigChan := make(chan os.Signal, 1)
|
||||||
signal.Notify(sigChan, syscall.SIGTERM)
|
signal.Notify(sigChan, syscall.SIGTERM)
|
||||||
signal.Notify(sigChan, syscall.SIGINT)
|
signal.Notify(sigChan, syscall.SIGINT)
|
||||||
signal.Notify(sigChan, syscall.SIGHUP)
|
signal.Notify(sigChan, syscall.SIGHUP)
|
||||||
|
|
||||||
<-sigChan
|
<-sigChan
|
||||||
if callback != nil {
|
|
||||||
callback()
|
|
||||||
}
|
|
||||||
|
|
||||||
os.Exit(0)
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,10 +5,7 @@ import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"os"
|
|
||||||
"os/signal"
|
|
||||||
"strings"
|
"strings"
|
||||||
"syscall"
|
|
||||||
|
|
||||||
grpc_prometheus "github.com/grpc-ecosystem/go-grpc-prometheus"
|
grpc_prometheus "github.com/grpc-ecosystem/go-grpc-prometheus"
|
||||||
"github.com/jmhodges/clock"
|
"github.com/jmhodges/clock"
|
||||||
|
@ -175,15 +172,10 @@ func (sb *serverBuilder) Build(tlsConfig *tls.Config, statsRegistry prometheus.R
|
||||||
// Start a goroutine which listens for a termination signal, and then
|
// Start a goroutine which listens for a termination signal, and then
|
||||||
// gracefully stops the gRPC server. This in turn causes the start() function
|
// gracefully stops the gRPC server. This in turn causes the start() function
|
||||||
// to exit, allowing its caller (generally a main() function) to exit.
|
// to exit, allowing its caller (generally a main() function) to exit.
|
||||||
go func() {
|
go cmd.CatchSignals(func() {
|
||||||
sigChan := make(chan os.Signal, 1)
|
|
||||||
signal.Notify(sigChan, syscall.SIGTERM)
|
|
||||||
signal.Notify(sigChan, syscall.SIGINT)
|
|
||||||
signal.Notify(sigChan, syscall.SIGHUP)
|
|
||||||
<-sigChan
|
|
||||||
healthSrv.Shutdown()
|
healthSrv.Shutdown()
|
||||||
server.GracefulStop()
|
server.GracefulStop()
|
||||||
}()
|
})
|
||||||
|
|
||||||
return start, nil
|
return start, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,13 +1,14 @@
|
||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"flag"
|
"flag"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"log"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"sync"
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/letsencrypt/boulder/akamai"
|
"github.com/letsencrypt/boulder/akamai"
|
||||||
"github.com/letsencrypt/boulder/cmd"
|
"github.com/letsencrypt/boulder/cmd"
|
||||||
|
@ -92,9 +93,23 @@ func main() {
|
||||||
w.Write(resp)
|
w.Write(resp)
|
||||||
})
|
})
|
||||||
|
|
||||||
// The gosec linter complains that timeouts cannot be set here. That's fine,
|
s := http.Server{
|
||||||
// because this is test-only code.
|
ReadTimeout: 30 * time.Second,
|
||||||
////nolint:gosec
|
Addr: *listenAddr,
|
||||||
go log.Fatal(http.ListenAndServe(*listenAddr, nil))
|
}
|
||||||
cmd.CatchSignals(nil, nil)
|
|
||||||
|
go func() {
|
||||||
|
err := s.ListenAndServe()
|
||||||
|
if err != nil && err != http.ErrServerClosed {
|
||||||
|
cmd.FailOnError(err, "Running TLS server")
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||||
|
defer cancel()
|
||||||
|
s.Shutdown(ctx)
|
||||||
|
}()
|
||||||
|
|
||||||
|
cmd.WaitForSignal()
|
||||||
}
|
}
|
||||||
|
|
|
@ -257,5 +257,5 @@ func main() {
|
||||||
for _, p := range c.Personalities {
|
for _, p := range c.Personalities {
|
||||||
go runPersonality(p)
|
go runPersonality(p)
|
||||||
}
|
}
|
||||||
cmd.CatchSignals(nil, nil)
|
cmd.WaitForSignal()
|
||||||
}
|
}
|
||||||
|
|
|
@ -31,7 +31,7 @@ import (
|
||||||
var (
|
var (
|
||||||
// stringToOperation maps a configured plan action to a function that can
|
// stringToOperation maps a configured plan action to a function that can
|
||||||
// operate on a state/context.
|
// operate on a state/context.
|
||||||
stringToOperation = map[string]func(*State, *context) error{
|
stringToOperation = map[string]func(*State, *acmeCache) error{
|
||||||
"newAccount": newAccount,
|
"newAccount": newAccount,
|
||||||
"getAccount": getAccount,
|
"getAccount": getAccount,
|
||||||
"newOrder": newOrder,
|
"newOrder": newOrder,
|
||||||
|
@ -60,8 +60,8 @@ type OrderJSON struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// getAccount takes a randomly selected v2 account from `state.accts` and puts it
|
// getAccount takes a randomly selected v2 account from `state.accts` and puts it
|
||||||
// into `ctx.acct`. The context `nonceSource` is also populated as convenience.
|
// into `c.acct`. The context `nonceSource` is also populated as convenience.
|
||||||
func getAccount(s *State, ctx *context) error {
|
func getAccount(s *State, c *acmeCache) error {
|
||||||
s.rMu.RLock()
|
s.rMu.RLock()
|
||||||
defer s.rMu.RUnlock()
|
defer s.rMu.RUnlock()
|
||||||
|
|
||||||
|
@ -71,8 +71,8 @@ func getAccount(s *State, ctx *context) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Select a random account from the state and put it into the context
|
// Select a random account from the state and put it into the context
|
||||||
ctx.acct = s.accts[mrand.Intn(len(s.accts))]
|
c.acct = s.accts[mrand.Intn(len(s.accts))]
|
||||||
ctx.ns = &nonceSource{s: s}
|
c.ns = &nonceSource{s: s}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -81,11 +81,11 @@ func getAccount(s *State, ctx *context) error {
|
||||||
// then `newAccount` puts an existing account from the state into the context,
|
// then `newAccount` puts an existing account from the state into the context,
|
||||||
// otherwise it creates a new account and puts it into both the state and the
|
// otherwise it creates a new account and puts it into both the state and the
|
||||||
// context.
|
// context.
|
||||||
func newAccount(s *State, ctx *context) error {
|
func newAccount(s *State, c *acmeCache) error {
|
||||||
// Check the max regs and if exceeded, just return an existing account instead
|
// Check the max regs and if exceeded, just return an existing account instead
|
||||||
// of creating a new one.
|
// of creating a new one.
|
||||||
if s.maxRegs != 0 && s.numAccts() >= s.maxRegs {
|
if s.maxRegs != 0 && s.numAccts() >= s.maxRegs {
|
||||||
return getAccount(s, ctx)
|
return getAccount(s, c)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create a random signing key
|
// Create a random signing key
|
||||||
|
@ -93,10 +93,10 @@ func newAccount(s *State, ctx *context) error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
ctx.acct = &account{
|
c.acct = &account{
|
||||||
key: signKey,
|
key: signKey,
|
||||||
}
|
}
|
||||||
ctx.ns = &nonceSource{s: s}
|
c.ns = &nonceSource{s: s}
|
||||||
|
|
||||||
// Prepare an account registration message body
|
// Prepare an account registration message body
|
||||||
reqBody := struct {
|
reqBody := struct {
|
||||||
|
@ -117,7 +117,7 @@ func newAccount(s *State, ctx *context) error {
|
||||||
// Sign the new account registration body using a JWS with an embedded JWK
|
// Sign the new account registration body using a JWS with an embedded JWK
|
||||||
// because we do not have a key ID from the server yet.
|
// because we do not have a key ID from the server yet.
|
||||||
newAccountURL := s.directory.EndpointURL(acme.NewAccountEndpoint)
|
newAccountURL := s.directory.EndpointURL(acme.NewAccountEndpoint)
|
||||||
jws, err := ctx.signEmbeddedV2Request(reqBodyStr, newAccountURL)
|
jws, err := c.signEmbeddedV2Request(reqBodyStr, newAccountURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -126,7 +126,7 @@ func newAccount(s *State, ctx *context) error {
|
||||||
resp, err := s.post(
|
resp, err := s.post(
|
||||||
newAccountURL,
|
newAccountURL,
|
||||||
bodyBuf,
|
bodyBuf,
|
||||||
ctx.ns,
|
c.ns,
|
||||||
string(acme.NewAccountEndpoint),
|
string(acme.NewAccountEndpoint),
|
||||||
http.StatusCreated)
|
http.StatusCreated)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -140,10 +140,10 @@ func newAccount(s *State, ctx *context) error {
|
||||||
if locHeader == "" {
|
if locHeader == "" {
|
||||||
return fmt.Errorf("%s, bad response - no Location header with account ID", newAccountURL)
|
return fmt.Errorf("%s, bad response - no Location header with account ID", newAccountURL)
|
||||||
}
|
}
|
||||||
ctx.acct.id = locHeader
|
c.acct.id = locHeader
|
||||||
|
|
||||||
// Add the account to the state
|
// Add the account to the state
|
||||||
s.addAccount(ctx.acct)
|
s.addAccount(c.acct)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -160,7 +160,7 @@ func randDomain(base string) string {
|
||||||
|
|
||||||
// newOrder creates a new pending order object for a random set of domains using
|
// newOrder creates a new pending order object for a random set of domains using
|
||||||
// the context's account.
|
// the context's account.
|
||||||
func newOrder(s *State, ctx *context) error {
|
func newOrder(s *State, c *acmeCache) error {
|
||||||
// Pick a random number of names within the constraints of the maxNamesPerCert
|
// Pick a random number of names within the constraints of the maxNamesPerCert
|
||||||
// parameter
|
// parameter
|
||||||
orderSize := 1 + mrand.Intn(s.maxNamesPerCert-1)
|
orderSize := 1 + mrand.Intn(s.maxNamesPerCert-1)
|
||||||
|
@ -187,7 +187,7 @@ func newOrder(s *State, ctx *context) error {
|
||||||
|
|
||||||
// Sign the new order request with the context account's key/key ID
|
// Sign the new order request with the context account's key/key ID
|
||||||
newOrderURL := s.directory.EndpointURL(acme.NewOrderEndpoint)
|
newOrderURL := s.directory.EndpointURL(acme.NewOrderEndpoint)
|
||||||
jws, err := ctx.signKeyIDV2Request(initOrderStr, newOrderURL)
|
jws, err := c.signKeyIDV2Request(initOrderStr, newOrderURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -196,7 +196,7 @@ func newOrder(s *State, ctx *context) error {
|
||||||
resp, err := s.post(
|
resp, err := s.post(
|
||||||
newOrderURL,
|
newOrderURL,
|
||||||
bodyBuf,
|
bodyBuf,
|
||||||
ctx.ns,
|
c.ns,
|
||||||
string(acme.NewOrderEndpoint),
|
string(acme.NewOrderEndpoint),
|
||||||
http.StatusCreated)
|
http.StatusCreated)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -223,24 +223,24 @@ func newOrder(s *State, ctx *context) error {
|
||||||
orderJSON.URL = orderURL
|
orderJSON.URL = orderURL
|
||||||
|
|
||||||
// Store the pending order in the context
|
// Store the pending order in the context
|
||||||
ctx.pendingOrders = append(ctx.pendingOrders, &orderJSON)
|
c.pendingOrders = append(c.pendingOrders, &orderJSON)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// popPendingOrder *removes* a random pendingOrder from the context, returning
|
// popPendingOrder *removes* a random pendingOrder from the context, returning
|
||||||
// it.
|
// it.
|
||||||
func popPendingOrder(ctx *context) *OrderJSON {
|
func popPendingOrder(c *acmeCache) *OrderJSON {
|
||||||
orderIndex := mrand.Intn(len(ctx.pendingOrders))
|
orderIndex := mrand.Intn(len(c.pendingOrders))
|
||||||
order := ctx.pendingOrders[orderIndex]
|
order := c.pendingOrders[orderIndex]
|
||||||
ctx.pendingOrders = append(ctx.pendingOrders[:orderIndex], ctx.pendingOrders[orderIndex+1:]...)
|
c.pendingOrders = append(c.pendingOrders[:orderIndex], c.pendingOrders[orderIndex+1:]...)
|
||||||
return order
|
return order
|
||||||
}
|
}
|
||||||
|
|
||||||
// getAuthorization fetches an authorization by GET-ing the provided URL. It
|
// getAuthorization fetches an authorization by GET-ing the provided URL. It
|
||||||
// records the latency and result of the GET operation in the state.
|
// records the latency and result of the GET operation in the state.
|
||||||
func getAuthorization(s *State, ctx *context, url string) (*core.Authorization, error) {
|
func getAuthorization(s *State, c *acmeCache, url string) (*core.Authorization, error) {
|
||||||
latencyTag := "/acme/authz/{ID}"
|
latencyTag := "/acme/authz/{ID}"
|
||||||
resp, err := postAsGet(s, ctx, url, latencyTag)
|
resp, err := postAsGet(s, c, url, latencyTag)
|
||||||
// If there was an error, note the state and return
|
// If there was an error, note the state and return
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("%s bad response: %s", url, err)
|
return nil, fmt.Errorf("%s bad response: %s", url, err)
|
||||||
|
@ -269,7 +269,7 @@ func getAuthorization(s *State, ctx *context, url string) (*core.Authorization,
|
||||||
// HTTP-01 challenge using the context's account and the state's challenge
|
// HTTP-01 challenge using the context's account and the state's challenge
|
||||||
// server. Aftering POSTing the authorization's HTTP-01 challenge the
|
// server. Aftering POSTing the authorization's HTTP-01 challenge the
|
||||||
// authorization will be polled waiting for a state change.
|
// authorization will be polled waiting for a state change.
|
||||||
func completeAuthorization(authz *core.Authorization, s *State, ctx *context) error {
|
func completeAuthorization(authz *core.Authorization, s *State, c *acmeCache) error {
|
||||||
// Skip if the authz isn't pending
|
// Skip if the authz isn't pending
|
||||||
if authz.Status != core.StatusPending {
|
if authz.Status != core.StatusPending {
|
||||||
return nil
|
return nil
|
||||||
|
@ -283,7 +283,7 @@ func completeAuthorization(authz *core.Authorization, s *State, ctx *context) er
|
||||||
}
|
}
|
||||||
|
|
||||||
// Compute the key authorization from the context account's key
|
// Compute the key authorization from the context account's key
|
||||||
jwk := &jose.JSONWebKey{Key: &ctx.acct.key.PublicKey}
|
jwk := &jose.JSONWebKey{Key: &c.acct.key.PublicKey}
|
||||||
thumbprint, err := jwk.Thumbprint(crypto.SHA256)
|
thumbprint, err := jwk.Thumbprint(crypto.SHA256)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -311,7 +311,7 @@ func completeAuthorization(authz *core.Authorization, s *State, ctx *context) er
|
||||||
}
|
}
|
||||||
|
|
||||||
// Prepare the Challenge POST body
|
// Prepare the Challenge POST body
|
||||||
jws, err := ctx.signKeyIDV2Request([]byte(`{}`), chalToSolve.URL)
|
jws, err := c.signKeyIDV2Request([]byte(`{}`), chalToSolve.URL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -320,7 +320,7 @@ func completeAuthorization(authz *core.Authorization, s *State, ctx *context) er
|
||||||
resp, err := s.post(
|
resp, err := s.post(
|
||||||
chalToSolve.URL,
|
chalToSolve.URL,
|
||||||
requestPayload,
|
requestPayload,
|
||||||
ctx.ns,
|
c.ns,
|
||||||
"/acme/challenge/{ID}", // We want all challenge POST latencies to be grouped
|
"/acme/challenge/{ID}", // We want all challenge POST latencies to be grouped
|
||||||
http.StatusOK,
|
http.StatusOK,
|
||||||
)
|
)
|
||||||
|
@ -337,7 +337,7 @@ func completeAuthorization(authz *core.Authorization, s *State, ctx *context) er
|
||||||
|
|
||||||
// Poll the authorization waiting for the challenge response to be recorded in
|
// Poll the authorization waiting for the challenge response to be recorded in
|
||||||
// a change of state. The polling may sleep and retry a few times if required
|
// a change of state. The polling may sleep and retry a few times if required
|
||||||
err = pollAuthorization(authz, s, ctx)
|
err = pollAuthorization(authz, s, c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -351,11 +351,11 @@ func completeAuthorization(authz *core.Authorization, s *State, ctx *context) er
|
||||||
// be valid. If the status is invalid, or if three GETs do not produce the
|
// be valid. If the status is invalid, or if three GETs do not produce the
|
||||||
// correct authorization state an error is returned. If no error is returned
|
// correct authorization state an error is returned. If no error is returned
|
||||||
// then the authorization is valid and ready.
|
// then the authorization is valid and ready.
|
||||||
func pollAuthorization(authz *core.Authorization, s *State, ctx *context) error {
|
func pollAuthorization(authz *core.Authorization, s *State, c *acmeCache) error {
|
||||||
authzURL := authz.ID
|
authzURL := authz.ID
|
||||||
for i := 0; i < 3; i++ {
|
for i := 0; i < 3; i++ {
|
||||||
// Fetch the authz by its URL
|
// Fetch the authz by its URL
|
||||||
authz, err := getAuthorization(s, ctx, authzURL)
|
authz, err := getAuthorization(s, c, authzURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -377,25 +377,25 @@ func pollAuthorization(authz *core.Authorization, s *State, ctx *context) error
|
||||||
// authorization's HTTP-01 challenge using the context's account, and finally
|
// authorization's HTTP-01 challenge using the context's account, and finally
|
||||||
// placing the now-ready-to-be-finalized order into the context's list of
|
// placing the now-ready-to-be-finalized order into the context's list of
|
||||||
// fulfilled orders.
|
// fulfilled orders.
|
||||||
func fulfillOrder(s *State, ctx *context) error {
|
func fulfillOrder(s *State, c *acmeCache) error {
|
||||||
// There must be at least one pending order in the context to fulfill
|
// There must be at least one pending order in the context to fulfill
|
||||||
if len(ctx.pendingOrders) == 0 {
|
if len(c.pendingOrders) == 0 {
|
||||||
return errors.New("no pending orders to fulfill")
|
return errors.New("no pending orders to fulfill")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get an order to fulfill from the context
|
// Get an order to fulfill from the context
|
||||||
order := popPendingOrder(ctx)
|
order := popPendingOrder(c)
|
||||||
|
|
||||||
// Each of its authorizations need to be processed
|
// Each of its authorizations need to be processed
|
||||||
for _, url := range order.Authorizations {
|
for _, url := range order.Authorizations {
|
||||||
// Fetch the authz by its URL
|
// Fetch the authz by its URL
|
||||||
authz, err := getAuthorization(s, ctx, url)
|
authz, err := getAuthorization(s, c, url)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Complete the authorization by solving a challenge
|
// Complete the authorization by solving a challenge
|
||||||
err = completeAuthorization(authz, s, ctx)
|
err = completeAuthorization(authz, s, c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -403,16 +403,16 @@ func fulfillOrder(s *State, ctx *context) error {
|
||||||
|
|
||||||
// Once all of the authorizations have been fulfilled the order is fulfilled
|
// Once all of the authorizations have been fulfilled the order is fulfilled
|
||||||
// and ready for future finalization.
|
// and ready for future finalization.
|
||||||
ctx.fulfilledOrders = append(ctx.fulfilledOrders, order.URL)
|
c.fulfilledOrders = append(c.fulfilledOrders, order.URL)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// getOrder GETs an order by URL, returning an OrderJSON object. It tracks the
|
// getOrder GETs an order by URL, returning an OrderJSON object. It tracks the
|
||||||
// latency of the GET operation in the provided state.
|
// latency of the GET operation in the provided state.
|
||||||
func getOrder(s *State, ctx *context, url string) (*OrderJSON, error) {
|
func getOrder(s *State, c *acmeCache, url string) (*OrderJSON, error) {
|
||||||
latencyTag := "/acme/order/{ID}"
|
latencyTag := "/acme/order/{ID}"
|
||||||
// POST-as-GET the order URL
|
// POST-as-GET the order URL
|
||||||
resp, err := postAsGet(s, ctx, url, latencyTag)
|
resp, err := postAsGet(s, c, url, latencyTag)
|
||||||
// If there was an error, track that result
|
// If there was an error, track that result
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("%s bad response: %s", url, err)
|
return nil, fmt.Errorf("%s bad response: %s", url, err)
|
||||||
|
@ -440,10 +440,10 @@ func getOrder(s *State, ctx *context, url string) (*OrderJSON, error) {
|
||||||
// valid such that a certificate URL for the order is known. Three attempts are
|
// valid such that a certificate URL for the order is known. Three attempts are
|
||||||
// made to check the order status, sleeping 3s between each. If these attempts
|
// made to check the order status, sleeping 3s between each. If these attempts
|
||||||
// expire without the status becoming valid an error is returned.
|
// expire without the status becoming valid an error is returned.
|
||||||
func pollOrderForCert(order *OrderJSON, s *State, ctx *context) (*OrderJSON, error) {
|
func pollOrderForCert(order *OrderJSON, s *State, c *acmeCache) (*OrderJSON, error) {
|
||||||
for i := 0; i < 3; i++ {
|
for i := 0; i < 3; i++ {
|
||||||
// Fetch the order by its URL
|
// Fetch the order by its URL
|
||||||
order, err := getOrder(s, ctx, order.URL)
|
order, err := getOrder(s, c, order.URL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -463,10 +463,10 @@ func pollOrderForCert(order *OrderJSON, s *State, ctx *context) (*OrderJSON, err
|
||||||
|
|
||||||
// popFulfilledOrder **removes** a fulfilled order from the context, returning
|
// popFulfilledOrder **removes** a fulfilled order from the context, returning
|
||||||
// it. Fulfilled orders have all of their authorizations satisfied.
|
// it. Fulfilled orders have all of their authorizations satisfied.
|
||||||
func popFulfilledOrder(ctx *context) string {
|
func popFulfilledOrder(c *acmeCache) string {
|
||||||
orderIndex := mrand.Intn(len(ctx.fulfilledOrders))
|
orderIndex := mrand.Intn(len(c.fulfilledOrders))
|
||||||
order := ctx.fulfilledOrders[orderIndex]
|
order := c.fulfilledOrders[orderIndex]
|
||||||
ctx.fulfilledOrders = append(ctx.fulfilledOrders[:orderIndex], ctx.fulfilledOrders[orderIndex+1:]...)
|
c.fulfilledOrders = append(c.fulfilledOrders[:orderIndex], c.fulfilledOrders[orderIndex+1:]...)
|
||||||
return order
|
return order
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -475,15 +475,15 @@ func popFulfilledOrder(ctx *context) string {
|
||||||
// `certKey`. The order is then polled for the status to change to valid so that
|
// `certKey`. The order is then polled for the status to change to valid so that
|
||||||
// the certificate URL can be added to the context. The context's `certs` list
|
// the certificate URL can be added to the context. The context's `certs` list
|
||||||
// is updated with the URL for the order's certificate.
|
// is updated with the URL for the order's certificate.
|
||||||
func finalizeOrder(s *State, ctx *context) error {
|
func finalizeOrder(s *State, c *acmeCache) error {
|
||||||
// There must be at least one fulfilled order in the context
|
// There must be at least one fulfilled order in the context
|
||||||
if len(ctx.fulfilledOrders) < 1 {
|
if len(c.fulfilledOrders) < 1 {
|
||||||
return errors.New("No fulfilled orders in the context ready to be finalized")
|
return errors.New("No fulfilled orders in the context ready to be finalized")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Pop a fulfilled order to process, and then GET its contents
|
// Pop a fulfilled order to process, and then GET its contents
|
||||||
orderID := popFulfilledOrder(ctx)
|
orderID := popFulfilledOrder(c)
|
||||||
order, err := getOrder(s, ctx, orderID)
|
order, err := getOrder(s, c, orderID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -519,7 +519,7 @@ func finalizeOrder(s *State, ctx *context) error {
|
||||||
)
|
)
|
||||||
|
|
||||||
// Sign the request body with the context's account key/keyID
|
// Sign the request body with the context's account key/keyID
|
||||||
jws, err := ctx.signKeyIDV2Request([]byte(request), finalizeURL)
|
jws, err := c.signKeyIDV2Request([]byte(request), finalizeURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -528,7 +528,7 @@ func finalizeOrder(s *State, ctx *context) error {
|
||||||
resp, err := s.post(
|
resp, err := s.post(
|
||||||
finalizeURL,
|
finalizeURL,
|
||||||
requestPayload,
|
requestPayload,
|
||||||
ctx.ns,
|
c.ns,
|
||||||
"/acme/order/finalize", // We want all order finalizations to be grouped.
|
"/acme/order/finalize", // We want all order finalizations to be grouped.
|
||||||
http.StatusOK,
|
http.StatusOK,
|
||||||
)
|
)
|
||||||
|
@ -544,7 +544,7 @@ func finalizeOrder(s *State, ctx *context) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Poll the order waiting for the certificate to be ready
|
// Poll the order waiting for the certificate to be ready
|
||||||
completedOrder, err := pollOrderForCert(order, s, ctx)
|
completedOrder, err := pollOrderForCert(order, s, c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -556,8 +556,8 @@ func finalizeOrder(s *State, ctx *context) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Append the certificate URL into the context's list of certificates
|
// Append the certificate URL into the context's list of certificates
|
||||||
ctx.certs = append(ctx.certs, certURL)
|
c.certs = append(c.certs, certURL)
|
||||||
ctx.finalizedOrders = append(ctx.finalizedOrders, order.URL)
|
c.finalizedOrders = append(c.finalizedOrders, order.URL)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -567,27 +567,27 @@ func finalizeOrder(s *State, ctx *context) error {
|
||||||
// responsible for closing the HTTP response body.
|
// responsible for closing the HTTP response body.
|
||||||
//
|
//
|
||||||
// See RFC 8555 Section 6.3 for more information on POST-as-GET requests.
|
// See RFC 8555 Section 6.3 for more information on POST-as-GET requests.
|
||||||
func postAsGet(s *State, ctx *context, url string, latencyTag string) (*http.Response, error) {
|
func postAsGet(s *State, c *acmeCache, url string, latencyTag string) (*http.Response, error) {
|
||||||
// Create the POST-as-GET request JWS
|
// Create the POST-as-GET request JWS
|
||||||
jws, err := ctx.signKeyIDV2Request([]byte(""), url)
|
jws, err := c.signKeyIDV2Request([]byte(""), url)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
requestPayload := []byte(jws.FullSerialize())
|
requestPayload := []byte(jws.FullSerialize())
|
||||||
|
|
||||||
return s.post(url, requestPayload, ctx.ns, latencyTag, http.StatusOK)
|
return s.post(url, requestPayload, c.ns, latencyTag, http.StatusOK)
|
||||||
}
|
}
|
||||||
|
|
||||||
func popCertificate(ctx *context) string {
|
func popCertificate(c *acmeCache) string {
|
||||||
certIndex := mrand.Intn(len(ctx.certs))
|
certIndex := mrand.Intn(len(c.certs))
|
||||||
certURL := ctx.certs[certIndex]
|
certURL := c.certs[certIndex]
|
||||||
ctx.certs = append(ctx.certs[:certIndex], ctx.certs[certIndex+1:]...)
|
c.certs = append(c.certs[:certIndex], c.certs[certIndex+1:]...)
|
||||||
return certURL
|
return certURL
|
||||||
}
|
}
|
||||||
|
|
||||||
func getCert(s *State, ctx *context, url string) ([]byte, error) {
|
func getCert(s *State, c *acmeCache, url string) ([]byte, error) {
|
||||||
latencyTag := "/acme/cert/{serial}"
|
latencyTag := "/acme/cert/{serial}"
|
||||||
resp, err := postAsGet(s, ctx, url, latencyTag)
|
resp, err := postAsGet(s, c, url, latencyTag)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("%s bad response: %s", url, err)
|
return nil, fmt.Errorf("%s bad response: %s", url, err)
|
||||||
}
|
}
|
||||||
|
@ -599,8 +599,8 @@ func getCert(s *State, ctx *context, url string) ([]byte, error) {
|
||||||
// and sends a revocation request for the certificate to the ACME server.
|
// and sends a revocation request for the certificate to the ACME server.
|
||||||
// The revocation request is signed with the account key rather than the certificate
|
// The revocation request is signed with the account key rather than the certificate
|
||||||
// key.
|
// key.
|
||||||
func revokeCertificate(s *State, ctx *context) error {
|
func revokeCertificate(s *State, c *acmeCache) error {
|
||||||
if len(ctx.certs) < 1 {
|
if len(c.certs) < 1 {
|
||||||
return errors.New("No certificates in the context that can be revoked")
|
return errors.New("No certificates in the context that can be revoked")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -608,8 +608,8 @@ func revokeCertificate(s *State, ctx *context) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
certURL := popCertificate(ctx)
|
certURL := popCertificate(c)
|
||||||
certPEM, err := getCert(s, ctx, certURL)
|
certPEM, err := getCert(s, c, certURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -630,7 +630,7 @@ func revokeCertificate(s *State, ctx *context) error {
|
||||||
revokeURL := s.directory.EndpointURL(acme.RevokeCertEndpoint)
|
revokeURL := s.directory.EndpointURL(acme.RevokeCertEndpoint)
|
||||||
// TODO(roland): randomly use the certificate key to sign the request instead of
|
// TODO(roland): randomly use the certificate key to sign the request instead of
|
||||||
// the account key
|
// the account key
|
||||||
jws, err := ctx.signKeyIDV2Request(revokeJSON, revokeURL)
|
jws, err := c.signKeyIDV2Request(revokeJSON, revokeURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -639,7 +639,7 @@ func revokeCertificate(s *State, ctx *context) error {
|
||||||
resp, err := s.post(
|
resp, err := s.post(
|
||||||
revokeURL,
|
revokeURL,
|
||||||
requestPayload,
|
requestPayload,
|
||||||
ctx.ns,
|
c.ns,
|
||||||
"/acme/revoke-cert",
|
"/acme/revoke-cert",
|
||||||
http.StatusOK,
|
http.StatusOK,
|
||||||
)
|
)
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"flag"
|
"flag"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
@ -118,9 +119,11 @@ func main() {
|
||||||
"HTTPOneAddrs, TLSALPNOneAddrs or DNSAddrs\n")
|
"HTTPOneAddrs, TLSALPNOneAddrs or DNSAddrs\n")
|
||||||
}
|
}
|
||||||
|
|
||||||
go cmd.CatchSignals(nil, nil)
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
go cmd.CatchSignals(cancel)
|
||||||
|
|
||||||
err = s.Run(
|
err = s.Run(
|
||||||
|
ctx,
|
||||||
config.HTTPOneAddrs,
|
config.HTTPOneAddrs,
|
||||||
config.TLSALPNOneAddrs,
|
config.TLSALPNOneAddrs,
|
||||||
config.DNSAddrs,
|
config.DNSAddrs,
|
||||||
|
|
|
@ -2,6 +2,7 @@ package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
"crypto/ecdsa"
|
"crypto/ecdsa"
|
||||||
"crypto/elliptic"
|
"crypto/elliptic"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
|
@ -15,14 +16,12 @@ import (
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
|
||||||
"reflect"
|
"reflect"
|
||||||
"runtime"
|
"runtime"
|
||||||
"sort"
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"syscall"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"gopkg.in/go-jose/go-jose.v2"
|
"gopkg.in/go-jose/go-jose.v2"
|
||||||
|
@ -54,7 +53,7 @@ func (acct *account) update(finalizedOrders, certs []string) {
|
||||||
acct.certs = append(acct.certs, certs...)
|
acct.certs = append(acct.certs, certs...)
|
||||||
}
|
}
|
||||||
|
|
||||||
type context struct {
|
type acmeCache struct {
|
||||||
// The current V2 account (may be nil for legacy load generation)
|
// The current V2 account (may be nil for legacy load generation)
|
||||||
acct *account
|
acct *account
|
||||||
// Pending orders waiting for authorization challenge validation
|
// Pending orders waiting for authorization challenge validation
|
||||||
|
@ -70,12 +69,12 @@ type context struct {
|
||||||
ns *nonceSource
|
ns *nonceSource
|
||||||
}
|
}
|
||||||
|
|
||||||
// signEmbeddedV2Request signs the provided request data using the context's
|
// signEmbeddedV2Request signs the provided request data using the acmeCache's
|
||||||
// account's private key. The provided URL is set as a protected header per ACME
|
// account's private key. The provided URL is set as a protected header per ACME
|
||||||
// v2 JWS standards. The resulting JWS contains an **embedded** JWK - this makes
|
// v2 JWS standards. The resulting JWS contains an **embedded** JWK - this makes
|
||||||
// this function primarily applicable to new account requests where no key ID is
|
// this function primarily applicable to new account requests where no key ID is
|
||||||
// known.
|
// known.
|
||||||
func (c *context) signEmbeddedV2Request(data []byte, url string) (*jose.JSONWebSignature, error) {
|
func (c *acmeCache) signEmbeddedV2Request(data []byte, url string) (*jose.JSONWebSignature, error) {
|
||||||
// Create a signing key for the account's private key
|
// Create a signing key for the account's private key
|
||||||
signingKey := jose.SigningKey{
|
signingKey := jose.SigningKey{
|
||||||
Key: c.acct.key,
|
Key: c.acct.key,
|
||||||
|
@ -101,13 +100,13 @@ func (c *context) signEmbeddedV2Request(data []byte, url string) (*jose.JSONWebS
|
||||||
return signed, nil
|
return signed, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// signKeyIDV2Request signs the provided request data using the context's
|
// signKeyIDV2Request signs the provided request data using the acmeCache's
|
||||||
// account's private key. The provided URL is set as a protected header per ACME
|
// account's private key. The provided URL is set as a protected header per ACME
|
||||||
// v2 JWS standards. The resulting JWS contains a Key ID header that is
|
// v2 JWS standards. The resulting JWS contains a Key ID header that is
|
||||||
// populated using the context's account's ID. This is the default JWS signing
|
// populated using the acmeCache's account's ID. This is the default JWS signing
|
||||||
// style for ACME v2 requests and should be used everywhere but where the key ID
|
// style for ACME v2 requests and should be used everywhere but where the key ID
|
||||||
// is unknown (e.g. new-account requests where an account doesn't exist yet).
|
// is unknown (e.g. new-account requests where an account doesn't exist yet).
|
||||||
func (c *context) signKeyIDV2Request(data []byte, url string) (*jose.JSONWebSignature, error) {
|
func (c *acmeCache) signKeyIDV2Request(data []byte, url string) (*jose.JSONWebSignature, error) {
|
||||||
// Create a JWK with the account's private key and key ID
|
// Create a JWK with the account's private key and key ID
|
||||||
jwk := &jose.JSONWebKey{
|
jwk := &jose.JSONWebKey{
|
||||||
Key: c.acct.key,
|
Key: c.acct.key,
|
||||||
|
@ -168,7 +167,7 @@ type State struct {
|
||||||
realIP string
|
realIP string
|
||||||
certKey *ecdsa.PrivateKey
|
certKey *ecdsa.PrivateKey
|
||||||
|
|
||||||
operations []func(*State, *context) error
|
operations []func(*State, *acmeCache) error
|
||||||
|
|
||||||
rMu sync.RWMutex
|
rMu sync.RWMutex
|
||||||
|
|
||||||
|
@ -349,6 +348,7 @@ func New(
|
||||||
|
|
||||||
// Run runs the WFE load-generator
|
// Run runs the WFE load-generator
|
||||||
func (s *State) Run(
|
func (s *State) Run(
|
||||||
|
ctx context.Context,
|
||||||
httpOneAddrs []string,
|
httpOneAddrs []string,
|
||||||
tlsALPNOneAddrs []string,
|
tlsALPNOneAddrs []string,
|
||||||
dnsAddrs []string,
|
dnsAddrs []string,
|
||||||
|
@ -387,8 +387,6 @@ func (s *State) Run(
|
||||||
|
|
||||||
// Run sending loop
|
// Run sending loop
|
||||||
stop := make(chan bool, 1)
|
stop := make(chan bool, 1)
|
||||||
sigs := make(chan os.Signal, 1)
|
|
||||||
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
|
|
||||||
fmt.Println("[+] Beginning execution plan")
|
fmt.Println("[+] Beginning execution plan")
|
||||||
i := int64(0)
|
i := int64(0)
|
||||||
go func() {
|
go func() {
|
||||||
|
@ -429,8 +427,8 @@ func (s *State) Run(
|
||||||
select {
|
select {
|
||||||
case <-time.After(p.Runtime):
|
case <-time.After(p.Runtime):
|
||||||
fmt.Println("[+] Execution plan finished")
|
fmt.Println("[+] Execution plan finished")
|
||||||
case sig := <-sigs:
|
case <-ctx.Done():
|
||||||
fmt.Printf("[!] Execution plan interrupted: %s caught\n", sig.String())
|
fmt.Println("[!] Execution plan cancelled")
|
||||||
}
|
}
|
||||||
stop <- true
|
stop <- true
|
||||||
fmt.Println("[+] Waiting for pending flows to finish before killing challenge server")
|
fmt.Println("[+] Waiting for pending flows to finish before killing challenge server")
|
||||||
|
@ -586,19 +584,19 @@ func (s *State) addAccount(acct *account) {
|
||||||
|
|
||||||
func (s *State) sendCall() {
|
func (s *State) sendCall() {
|
||||||
defer s.wg.Done()
|
defer s.wg.Done()
|
||||||
ctx := &context{}
|
c := &acmeCache{}
|
||||||
|
|
||||||
for _, op := range s.operations {
|
for _, op := range s.operations {
|
||||||
err := op(s, ctx)
|
err := op(s, c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
method := runtime.FuncForPC(reflect.ValueOf(op).Pointer()).Name()
|
method := runtime.FuncForPC(reflect.ValueOf(op).Pointer()).Name()
|
||||||
fmt.Printf("[FAILED] %s: %s\n", method, err)
|
fmt.Printf("[FAILED] %s: %s\n", method, err)
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// If the context's V2 account isn't nil, update it based on the context's
|
// If the acmeCache's V2 account isn't nil, update it based on the cache's
|
||||||
// finalizedOrders and certs.
|
// finalizedOrders and certs.
|
||||||
if ctx.acct != nil {
|
if c.acct != nil {
|
||||||
ctx.acct.update(ctx.finalizedOrders, ctx.certs)
|
c.acct.update(c.finalizedOrders, c.certs)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,6 +3,7 @@ package main
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"flag"
|
"flag"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
@ -186,12 +187,20 @@ scan:
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (srv *mailSrv) serveSMTP(l net.Listener) error {
|
func (srv *mailSrv) serveSMTP(ctx context.Context, l net.Listener) error {
|
||||||
for {
|
for {
|
||||||
conn, err := l.Accept()
|
conn, err := l.Accept()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
// If the accept call returned an error because the listener has been
|
||||||
|
// closed, then the context should have been canceled too. In that case,
|
||||||
|
// ignore the error.
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return nil
|
||||||
|
default:
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
}
|
||||||
go srv.handleConn(conn)
|
go srv.handleConn(conn)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -233,10 +242,10 @@ func main() {
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
go cmd.CatchSignals(nil, nil)
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
err = srv.serveSMTP(l)
|
go cmd.FailOnError(srv.serveSMTP(ctx, l), "Failed to accept connection")
|
||||||
if err != nil {
|
|
||||||
log.Fatalln(err, "Failed to accept connection")
|
cmd.WaitForSignal()
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,12 +1,13 @@
|
||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"flag"
|
"flag"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"log"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"sync"
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/letsencrypt/boulder/cmd"
|
"github.com/letsencrypt/boulder/cmd"
|
||||||
"github.com/letsencrypt/boulder/core"
|
"github.com/letsencrypt/boulder/core"
|
||||||
|
@ -93,9 +94,23 @@ func main() {
|
||||||
http.HandleFunc("/clear", srv.handleClear)
|
http.HandleFunc("/clear", srv.handleClear)
|
||||||
http.HandleFunc("/query", srv.handleQuery)
|
http.HandleFunc("/query", srv.handleQuery)
|
||||||
|
|
||||||
// The gosec linter complains that timeouts cannot be set here. That's fine,
|
s := http.Server{
|
||||||
// because this is test-only code.
|
ReadTimeout: 30 * time.Second,
|
||||||
////nolint:gosec
|
Addr: *listenAddr,
|
||||||
go log.Fatal(http.ListenAndServe(*listenAddr, nil))
|
}
|
||||||
cmd.CatchSignals(nil, nil)
|
|
||||||
|
go func() {
|
||||||
|
err := s.ListenAndServe()
|
||||||
|
if err != nil && err != http.ErrServerClosed {
|
||||||
|
cmd.FailOnError(err, "Running TLS server")
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||||
|
defer cancel()
|
||||||
|
s.Shutdown(ctx)
|
||||||
|
}()
|
||||||
|
|
||||||
|
cmd.WaitForSignal()
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue