From 6149ac63e632e287d13d6730f44b44dedfdf8571 Mon Sep 17 00:00:00 2001 From: Aaron Gable Date: Wed, 27 Mar 2024 17:17:39 -0700 Subject: [PATCH] admin: gain ability to block keys by spki hash (#7397) Add a new input method flag to `admin block-key` which processes a file containing one hexadecimal-encoded SPKI hash on each line. To facilitate this, restructure the block-key subcommand's execution to more closely resemble the revoke-cert subcommand, with a parallelism flag and the ability to run many workers at the same time. Part of https://github.com/letsencrypt/boulder/issues/7267 --- cmd/admin/key.go | 117 +++++++++++++++++++++++++++++++++++++----- cmd/admin/key_test.go | 30 ++++++++++- 2 files changed, 132 insertions(+), 15 deletions(-) diff --git a/cmd/admin/key.go b/cmd/admin/key.go index 650d3b9c1..c27406918 100644 --- a/cmd/admin/key.go +++ b/cmd/admin/key.go @@ -1,16 +1,22 @@ package main import ( + "bufio" "context" + "encoding/hex" "errors" "flag" "fmt" "io" + "os" "os/user" + "sync" + "golang.org/x/exp/maps" "google.golang.org/protobuf/types/known/timestamppb" "github.com/letsencrypt/boulder/core" + berrors "github.com/letsencrypt/boulder/errors" "github.com/letsencrypt/boulder/privatekey" sapb "github.com/letsencrypt/boulder/sa/proto" ) @@ -18,20 +24,48 @@ import ( // subcommandBlockKey encapsulates the "admin block-key" command. func (a *admin) subcommandBlockKey(ctx context.Context, args []string) error { subflags := flag.NewFlagSet("block-key", flag.ExitOnError) - privKey := subflags.String("private-key", "", "Block issuance for the pubkey corresponding to this private key") + + // General flags relevant to all key input methods. + parallelism := subflags.Uint("parallelism", 10, "Number of concurrent workers to use while blocking keys") comment := subflags.String("comment", "", "Additional context to add to database comment column") + + // Flags specifying the input method for the keys to be blocked. + privKey := subflags.String("private-key", "", "Block issuance for the pubkey corresponding to this private key") + spkiFile := subflags.String("spki-file", "", "Block issuance for all keys listed in this file as SHA256 hashes of SPKI, hex encoded, one per line") + _ = subflags.Parse(args) - if *privKey == "" { - return errors.New("the -private-key flag is required") + // This is a map of all input-selection flags to whether or not they were set + // to a non-default value. We use this to ensure that exactly one input + // selection flag was given on the command line. + setInputs := map[string]bool{ + "-private-key": *privKey != "", + "-spki-file": *spkiFile != "", + } + maps.DeleteFunc(setInputs, func(_ string, v bool) bool { return !v }) + if len(setInputs) == 0 { + return errors.New("at least one input method flag must be specified") + } else if len(setInputs) > 1 { + return fmt.Errorf("more than one input method flag specified: %v", maps.Keys(setInputs)) } - spkiHash, err := a.spkiHashFromPrivateKey(*privKey) + var spkiHashes [][]byte + var err error + switch maps.Keys(setInputs)[0] { + case "-private-key": + var spkiHash []byte + spkiHash, err = a.spkiHashFromPrivateKey(*privKey) + spkiHashes = [][]byte{spkiHash} + case "-spki-file": + spkiHashes, err = a.spkiHashesFromFile(*spkiFile) + default: + return errors.New("no recognized input method flag set (this shouldn't happen)") + } if err != nil { - return err + return fmt.Errorf("collecting serials to revoke: %w", err) } - err = a.blockSPKIHash(ctx, spkiHash, *comment) + err = a.blockSPKIHashes(ctx, spkiHashes, *comment, int(*parallelism)) if err != nil { return err } @@ -53,13 +87,75 @@ func (a *admin) spkiHashFromPrivateKey(keyFile string) ([]byte, error) { return spkiHash[:], nil } -func (a *admin) blockSPKIHash(ctx context.Context, spkiHash []byte, comment string) error { +func (a *admin) spkiHashesFromFile(filePath string) ([][]byte, error) { + file, err := os.Open(filePath) + if err != nil { + return nil, fmt.Errorf("opening spki hashes file: %w", err) + } + + var spkiHashes [][]byte + scanner := bufio.NewScanner(file) + for scanner.Scan() { + spkiHex := scanner.Text() + if spkiHex == "" { + continue + } + spkiHash, err := hex.DecodeString(spkiHex) + if err != nil { + return nil, fmt.Errorf("decoding hex spki hash %q: %w", spkiHex, err) + } + + if len(spkiHash) != 32 { + return nil, fmt.Errorf("got spki hash of unexpected length: %q (%d)", spkiHex, len(spkiHash)) + } + + spkiHashes = append(spkiHashes, spkiHash) + } + + return spkiHashes, nil +} + +func (a *admin) blockSPKIHashes(ctx context.Context, spkiHashes [][]byte, comment string, parallelism int) error { + u, err := user.Current() + if err != nil { + return fmt.Errorf("getting admin username: %w", err) + } + + wg := new(sync.WaitGroup) + work := make(chan []byte, parallelism) + for i := 0; i < parallelism; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for spkiHash := range work { + err = a.blockSPKIHash(ctx, spkiHash, u, comment) + if err != nil { + if errors.Is(err, berrors.AlreadyRevoked) { + a.log.Errf("not blocking %x: already blocked", spkiHash) + } else { + a.log.Errf("failed to block %x: %s", spkiHash, err) + } + } + } + }() + } + + for _, spkiHash := range spkiHashes { + work <- spkiHash + } + close(work) + wg.Wait() + + return nil +} + +func (a *admin) blockSPKIHash(ctx context.Context, spkiHash []byte, u *user.User, comment string) error { exists, err := a.saroc.KeyBlocked(ctx, &sapb.SPKIHash{KeyHash: spkiHash}) if err != nil { return fmt.Errorf("checking if key is already blocked: %w", err) } if exists.Exists { - return errors.New("the provided key already exists in the 'blockedKeys' table") + return berrors.AlreadyRevokedError("the provided key already exists in the 'blockedKeys' table") } stream, err := a.saroc.GetSerialsByKey(ctx, &sapb.SPKIHash{KeyHash: spkiHash}) @@ -81,11 +177,6 @@ func (a *admin) blockSPKIHash(ctx context.Context, spkiHash []byte, comment stri a.log.Infof("Found %d unexpired certificates matching the provided key", count) - u, err := user.Current() - if err != nil { - return fmt.Errorf("getting admin username: %w", err) - } - _, err = a.sac.AddBlockedKey(ctx, &sapb.AddBlockedKeyRequest{ KeyHash: spkiHash[:], Added: timestamppb.New(a.clk.Now()), diff --git a/cmd/admin/key_test.go b/cmd/admin/key_test.go index 23fe77bcc..2ad045848 100644 --- a/cmd/admin/key_test.go +++ b/cmd/admin/key_test.go @@ -5,10 +5,15 @@ import ( "crypto/ecdsa" "crypto/elliptic" "crypto/rand" + "crypto/sha256" "crypto/x509" + "encoding/hex" "encoding/pem" "os" + "os/user" "path" + "strconv" + "strings" "testing" "time" @@ -43,6 +48,26 @@ func TestSPKIHashFromPrivateKey(t *testing.T) { test.AssertByteEquals(t, res, keyHash[:]) } +func TestSPKIHashesFromFile(t *testing.T) { + var spkiHexes []string + for i := 0; i < 10; i++ { + h := sha256.Sum256([]byte(strconv.Itoa(i))) + spkiHexes = append(spkiHexes, hex.EncodeToString(h[:])) + } + + spkiFile := path.Join(t.TempDir(), "spkis.txt") + err := os.WriteFile(spkiFile, []byte(strings.Join(spkiHexes, "\n")), os.ModeAppend) + test.AssertNotError(t, err, "writing test spki file") + + a := admin{} + + res, err := a.spkiHashesFromFile(spkiFile) + test.AssertNotError(t, err, "") + for i, spkiHash := range res { + test.AssertEquals(t, hex.EncodeToString(spkiHash), spkiHexes[i]) + } +} + // mockSARecordingBlocks is a mock which only implements the AddBlockedKey gRPC // method. type mockSARecordingBlocks struct { @@ -73,12 +98,13 @@ func TestBlockSPKIHash(t *testing.T) { test.AssertNotError(t, err, "computing test SPKI hash") a := admin{saroc: &mocks.StorageAuthorityReadOnly{}, sac: &msa, clk: fc, log: log} + u := &user.User{} // A full run should result in one request with the right fields. msa.reset() log.Clear() a.dryRun = false - err = a.blockSPKIHash(context.Background(), keyHash[:], "hello world") + err = a.blockSPKIHash(context.Background(), keyHash[:], u, "hello world") test.AssertNotError(t, err, "") test.AssertEquals(t, len(log.GetAllMatching("Found 0 unexpired certificates")), 1) test.AssertEquals(t, len(msa.blockRequests), 1) @@ -90,7 +116,7 @@ func TestBlockSPKIHash(t *testing.T) { log.Clear() a.dryRun = true a.sac = dryRunSAC{log: log} - err = a.blockSPKIHash(context.Background(), keyHash[:], "") + err = a.blockSPKIHash(context.Background(), keyHash[:], u, "") test.AssertNotError(t, err, "") test.AssertEquals(t, len(log.GetAllMatching("Found 0 unexpired certificates")), 1) test.AssertEquals(t, len(log.GetAllMatching("dry-run:")), 1)