From d05a0f6afe084e45a4fd32792334257f8f6c227b Mon Sep 17 00:00:00 2001 From: Jacob Hoffman-Andrews Date: Fri, 9 Sep 2022 11:29:38 -0700 Subject: [PATCH] rocsp-tool: fix printing of commandline help (#6360) This refactors how rocsp-tool processes its subcommands so we're less likely to make mistakes (like under- or over- including subcommands in the help output). --- cmd/rocsp-tool/main.go | 185 +++++++++++++++++++++++++---------------- 1 file changed, 112 insertions(+), 73 deletions(-) diff --git a/cmd/rocsp-tool/main.go b/cmd/rocsp-tool/main.go index a44079232..5d429e856 100644 --- a/cmd/rocsp-tool/main.go +++ b/cmd/rocsp-tool/main.go @@ -8,6 +8,7 @@ import ( "fmt" "math/rand" "os" + "strings" "time" "github.com/jmhodges/clock" @@ -73,27 +74,29 @@ func main() { } } +var startFromID = flag.Int64("start-from-id", 0, "For load-from-db, the first ID in the certificateStatus table to scan") + func main2() error { configFile := flag.String("config", "", "File path to the configuration file for this service") - startFromID := flag.Int64("start-from-id", 0, "For load-from-db, the first ID in the certificateStatus table to scan") + flag.Usage = helpExit flag.Parse() - if *configFile == "" { - flag.Usage() - os.Exit(1) + if *configFile == "" || len(flag.Args()) < 1 { + helpExit() } + rand.Seed(time.Now().UnixNano()) - var c Config - err := cmd.ReadConfigFile(*configFile, &c) + var conf Config + err := cmd.ReadConfigFile(*configFile, &conf) if err != nil { return fmt.Errorf("reading JSON config file: %w", err) } - _, logger := cmd.StatsAndLogging(c.Syslog, c.ROCSPTool.DebugAddr) + _, logger := cmd.StatsAndLogging(conf.Syslog, conf.ROCSPTool.DebugAddr) defer logger.AuditPanic() clk := cmd.Clock() - redisClient, err := rocsp_config.MakeClient(&c.ROCSPTool.Redis, clk, metrics.NoopRegisterer) + redisClient, err := rocsp_config.MakeClient(&conf.ROCSPTool.Redis, clk, metrics.NoopRegisterer) if err != nil { return fmt.Errorf("making client: %w", err) } @@ -101,8 +104,8 @@ func main2() error { var db *db.WrappedMap var ocspGenerator capb.OCSPGeneratorClient var scanBatchSize int - if c.ROCSPTool.LoadFromDB != nil { - lfd := c.ROCSPTool.LoadFromDB + if conf.ROCSPTool.LoadFromDB != nil { + lfd := conf.ROCSPTool.LoadFromDB db, err = sa.InitWrappedDb(lfd.DB, nil, logger) if err != nil { return fmt.Errorf("connecting to DB: %w", err) @@ -119,10 +122,6 @@ func main2() error { scanBatchSize = lfd.Speed.ScanBatchSize } - if len(flag.Args()) < 1 { - helpExit() - } - ctx := context.Background() cl := client{ redis: redisClient, @@ -132,71 +131,111 @@ func main2() error { scanBatchSize: scanBatchSize, logger: logger, } - switch flag.Arg(0) { - case "get": - for _, serial := range flag.Args()[1:] { - resp, err := cl.redis.GetResponse(ctx, serial) - if err != nil { - return err - } - parsed, err := ocsp.ParseResponse(resp, nil) - if err != nil { - fmt.Fprintf(os.Stderr, "parsing error on %x: %s", resp, err) - continue - } else { - fmt.Printf("%s\n", helper.PrettyResponse(parsed)) - } + + for _, sc := range subCommands { + if flag.Arg(0) == sc.name { + return sc.cmd(ctx, cl, conf, flag.Args()[1:]) } - case "get-pem": - for _, serial := range flag.Args()[1:] { - resp, err := cl.redis.GetResponse(ctx, serial) - if err != nil { - return err - } - block := pem.Block{ - Bytes: resp, - Type: "OCSP RESPONSE", - } - pem.Encode(os.Stdout, &block) - } - case "store": - logger.Info(cmd.VersionString()) - err := cl.storeResponsesFromFiles(ctx, flag.Args()[1:]) - if err != nil { - return err - } - case "load-from-db": - logger.Info(cmd.VersionString()) - if c.ROCSPTool.LoadFromDB == nil { - return fmt.Errorf("config field LoadFromDB was missing") - } - err = cl.loadFromDB(ctx, c.ROCSPTool.LoadFromDB.Speed, *startFromID) - if err != nil { - return fmt.Errorf("loading OCSP responses from DB: %w", err) - } - case "scan-responses": - logger.Info(cmd.VersionString()) - results := cl.redis.ScanResponses(ctx, "*") - for r := range results { - if r.Err != nil { - cmd.FailOnError(err, "while scanning") - } - logger.Infof("%s: %s\n", r.Serial, base64.StdEncoding.EncodeToString(r.Body)) - } - default: - logger.Errf("unrecognized subcommand %q\n", flag.Arg(0)) - helpExit() } + fmt.Fprintf(os.Stderr, "unrecognized subcommand %q\n", flag.Arg(0)) + helpExit() return nil } +// subCommand represents a single subcommand. `name` is the name used to invoke it, and `help` is +// its help text. +type subCommand struct { + name string + help string + cmd func(context.Context, client, Config, []string) error +} + +var ( + Store = subCommand{"store", "for each filename on command line, read the file as an OCSP response and store it in Redis", + func(ctx context.Context, cl client, _ Config, args []string) error { + err := cl.storeResponsesFromFiles(ctx, flag.Args()[1:]) + if err != nil { + return err + } + return nil + }, + } + Get = subCommand{ + "get", + "for each serial on command line, fetch that serial's response and pretty-print it", + func(ctx context.Context, cl client, _ Config, args []string) error { + for _, serial := range flag.Args()[1:] { + resp, err := cl.redis.GetResponse(ctx, serial) + if err != nil { + return err + } + parsed, err := ocsp.ParseResponse(resp, nil) + if err != nil { + fmt.Fprintf(os.Stderr, "parsing error on %x: %s", resp, err) + continue + } else { + fmt.Printf("%s\n", helper.PrettyResponse(parsed)) + } + } + return nil + }, + } + GetPEM = subCommand{"get-pem", "for each serial on command line, fetch that serial's response and print it PEM-encoded", + func(ctx context.Context, cl client, _ Config, args []string) error { + for _, serial := range flag.Args()[1:] { + resp, err := cl.redis.GetResponse(ctx, serial) + if err != nil { + return err + } + block := pem.Block{ + Bytes: resp, + Type: "OCSP RESPONSE", + } + pem.Encode(os.Stdout, &block) + } + return nil + }, + } + LoadFromDB = subCommand{"load-from-db", "scan the database for all OCSP entries for unexpired certificates, and store in Redis", + func(ctx context.Context, cl client, c Config, args []string) error { + if c.ROCSPTool.LoadFromDB == nil { + return fmt.Errorf("config field LoadFromDB was missing") + } + err := cl.loadFromDB(ctx, c.ROCSPTool.LoadFromDB.Speed, *startFromID) + if err != nil { + return fmt.Errorf("loading OCSP responses from DB: %w", err) + } + return nil + }, + } + ScanResponses = subCommand{"scan-responses", "scan Redis for OCSP response entries. For each entry, print the serial and base64-encoded response", + func(ctx context.Context, cl client, _ Config, args []string) error { + results := cl.redis.ScanResponses(ctx, "*") + for r := range results { + if r.Err != nil { + return r.Err + } + fmt.Printf("%s: %s\n", r.Serial, base64.StdEncoding.EncodeToString(r.Body)) + } + return nil + }, + } +) + +var subCommands = []subCommand{ + Store, Get, GetPEM, LoadFromDB, ScanResponses, +} + func helpExit() { - fmt.Fprintf(os.Stderr, "Usage: %s [store|copy-from-db|scan-metadata|scan-responses] --config path/to/config.json\n", os.Args[0]) - fmt.Fprintln(os.Stderr, " store -- for each filename on command line, read the file as an OCSP response and store it in Redis") - fmt.Fprintln(os.Stderr, " get -- for each serial on command line, fetch that serial's response and pretty-print it") - fmt.Fprintln(os.Stderr, " load-from-db -- scan the database for all OCSP entries for unexpired certificates, and store in Redis") - fmt.Fprintln(os.Stderr, " scan-metadata -- scan Redis for metadata entries. For each entry, print the serial and the age in hours") - fmt.Fprintln(os.Stderr, " scan-responses -- scan Redis for OCSP response entries. For each entry, print the serial and base64-encoded response") + var names []string + var helpStrings []string + for _, s := range subCommands { + names = append(names, s.name) + helpStrings = append(helpStrings, fmt.Sprintf(" %s -- %s", s.name, s.help)) + } + fmt.Fprintf(os.Stderr, "Usage: %s [%s] --config path/to/config.json\n", os.Args[0], strings.Join(names, "|")) + os.Stderr.Write([]byte(strings.Join(helpStrings, "\n"))) + fmt.Fprintln(os.Stderr) fmt.Fprintln(os.Stderr) flag.PrintDefaults() os.Exit(1)