diff --git a/cmd/id-exporter/main.go b/cmd/id-exporter/main.go index 1e4e6fe7f..de2c2f46b 100644 --- a/cmd/id-exporter/main.go +++ b/cmd/id-exporter/main.go @@ -24,15 +24,53 @@ type idExporter struct { grace time.Duration } -type id struct { +// resultEntry is a JSON marshalable exporter result entry. +type resultEntry struct { + // ID is exported to support marshaling to JSON. ID int64 `json:"id"` + + // Hostname is exported to support marshaling to JSON. Not all queries + // will fill this field, so it's JSON field tag marks at as + // omittable. + Hostname string `json:"hostname,omitempty"` +} + +// reverseHostname converts (reversed) names sourced from the +// registrations table to standard hostnames. +func (r *resultEntry) reverseHostname() { + r.Hostname = sa.ReverseName(r.Hostname) +} + +// idExporterResults is passed as a selectable 'holder' for the results +// of id-exporter database queries +type idExporterResults []*resultEntry + +// marshalToJSON returns JSON as bytes for all elements of the inner `id` +// slice. +func (i *idExporterResults) marshalToJSON() ([]byte, error) { + data, err := json.Marshal(i) + if err != nil { + return nil, err + } + data = append(data, '\n') + return data, nil +} + +// writeToFile writes the contents of the inner `ids` slice, as JSON, to +// a file +func (i *idExporterResults) writeToFile(outfile string) error { + data, err := i.marshalToJSON() + if err != nil { + return err + } + return ioutil.WriteFile(outfile, data, 0644) } // Find all registration IDs with unexpired certificates. -func (c idExporter) findIDs() ([]id, error) { - var idsList []id +func (c idExporter) findIDs() (idExporterResults, error) { + var holder idExporterResults _, err := c.dbMap.Select( - &idsList, + &holder, `SELECT id FROM registrations WHERE contact != 'null' AND @@ -48,18 +86,44 @@ func (c idExporter) findIDs() ([]id, error) { c.log.AuditErrf("Error finding IDs: %s", err) return nil, err } - - return idsList, nil + return holder, nil } -func (c idExporter) findIDsForDomains(domains []string) ([]id, error) { - var idsList []id +// Find all registration IDs with unexpired certificates and gather an +// example hostname. +func (c idExporter) findIDsWithExampleHostnames() (idExporterResults, error) { + var holder idExporterResults + _, err := c.dbMap.Select( + &holder, + `SELECT SQL_BIG_RESULT + cert.registrationID AS id, + name.reversedName AS hostname + FROM certificates AS cert + INNER JOIN issuedNames AS name ON name.serial = cert.serial + WHERE cert.expires >= :expireCutoff + GROUP BY cert.registrationID;`, + map[string]interface{}{ + "expireCutoff": c.clk.Now().Add(-c.grace), + }) + if err != nil { + c.log.AuditErrf("Error finding IDs and example hostnames: %s", err) + return nil, err + } + + for _, result := range holder { + result.reverseHostname() + } + return holder, nil +} + +func (c idExporter) findIDsForDomains(domains []string) (idExporterResults, error) { + var holder idExporterResults for _, domain := range domains { // Pass the same list in each time, gorp will happily just append to the slice // instead of overwriting it each time // https://github.com/go-gorp/gorp/blob/2ae7d174a4cf270240c4561092402affba25da5e/select.go#L348-L355 _, err := c.dbMap.Select( - &idsList, + &holder, `SELECT registrationID AS id FROM certificates WHERE expires >= :expireCutoff AND serial IN ( @@ -79,24 +143,7 @@ func (c idExporter) findIDsForDomains(domains []string) ([]id, error) { } } - return idsList, nil -} - -// The `writeIDs` function produces a file containing JSON serialized -// contact objects -func writeIDs(idsList []id, outfile string) error { - data, err := json.Marshal(idsList) - if err != nil { - return err - } - data = append(data, '\n') - - if outfile != "" { - return ioutil.WriteFile(outfile, data, 0644) - } - - fmt.Printf("%s", data) - return nil + return holder, nil } const usageIntro = ` @@ -117,11 +164,18 @@ mailing is underway, ensuring we use the correct address if a user has updated their contact information between the time of export and the time of notification. -The ID exporter's output will be JSON of the form: +By default, the ID exporter's output will be JSON of the form: [ - { "id": 1 }, - ... - { "id": n } + { "id": 1 }, + ... + { "id": n } + ] + +Operations that return a hostname will be JSON of the form: + [ + { "id": 1, "hostname": "example-1.com" }, + ... + { "id": n, "hostname": "example-n.com" } ] Examples: @@ -143,6 +197,7 @@ func main() { outFile := flag.String("outfile", "", "File to write contacts to (defaults to stdout).") grace := flag.Duration("grace", 2*24*time.Hour, "Include contacts with certificates that expired in < grace ago") domainsFile := flag.String("domains", "", "If provided only output contacts for certificates that contain at least one of the domains in the provided file. Provided file should contain one domain per line") + withExampleHostnames := flag.Bool("with-example-hostnames", false, "In addition to IDs, gather an example domain name that corresponds to that ID") type config struct { ContactExporter struct { DB cmd.DBConfig @@ -189,17 +244,27 @@ func main() { grace: *grace, } - var ids []id + var results idExporterResults if *domainsFile != "" { + // Gather IDs for the domains listed in the `domainsFile`. df, err := ioutil.ReadFile(*domainsFile) cmd.FailOnError(err, fmt.Sprintf("Could not read domains file %q", *domainsFile)) - ids, err = exporter.findIDsForDomains(strings.Split(string(df), "\n")) - cmd.FailOnError(err, "Could not find IDs") + + results, err = exporter.findIDsForDomains(strings.Split(string(df), "\n")) + cmd.FailOnError(err, "Could not find IDs for domains") + + } else if *withExampleHostnames { + // Gather subscriber IDs and hostnames. + results, err = exporter.findIDsWithExampleHostnames() + cmd.FailOnError(err, "Could not find IDs with hostnames") + } else { - ids, err = exporter.findIDs() + // Gather only subscriber IDs. + results, err = exporter.findIDs() cmd.FailOnError(err, "Could not find IDs") } - err = writeIDs(ids, *outFile) - cmd.FailOnError(err, fmt.Sprintf("Could not write IDs to outfile %q", *outFile)) + // Write results to file. + err = results.writeToFile(*outFile) + cmd.FailOnError(err, fmt.Sprintf("Could not write result to outfile %q", *outFile)) } diff --git a/cmd/id-exporter/main_test.go b/cmd/id-exporter/main_test.go index dcb58cd68..3dbba5ac2 100644 --- a/cmd/id-exporter/main_test.go +++ b/cmd/id-exporter/main_test.go @@ -49,9 +49,9 @@ func TestFindIDs(t *testing.T) { // Run findIDs - since no certificates have been added corresponding to // the above registrations, no IDs should be found. - ids, err := testCtx.c.findIDs() + results, err := testCtx.c.findIDs() test.AssertNotError(t, err, "findIDs() produced error") - test.AssertEquals(t, len(ids), 0) + test.AssertEquals(t, len(results), 0) // Now add some certificates testCtx.addCertificates(t) @@ -61,24 +61,96 @@ func TestFindIDs(t *testing.T) { // *not* be present since their certificate has already expired. Unlike // previous versions of this test RegD is not filtered out for having a `tel:` // contact field anymore - this is the duty of the notify-mailer. - ids, err = testCtx.c.findIDs() + results, err = testCtx.c.findIDs() test.AssertNotError(t, err, "findIDs() produced error") - test.AssertEquals(t, len(ids), 3) - test.AssertEquals(t, ids[0].ID, regA.ID) - test.AssertEquals(t, ids[1].ID, regC.ID) - test.AssertEquals(t, ids[2].ID, regD.ID) + test.AssertEquals(t, len(results), 3) + for _, entry := range results { + switch entry.ID { + case regA.ID: + case regC.ID: + case regD.ID: + default: + t.Errorf("ID: %d not expected", entry.ID) + } + } // Allow a 1 year grace period testCtx.c.grace = 360 * 24 * time.Hour - ids, err = testCtx.c.findIDs() + results, err = testCtx.c.findIDs() test.AssertNotError(t, err, "findIDs() produced error") // Now all four registration should be returned, including RegB since its // certificate expired within the grace period - test.AssertEquals(t, len(ids), 4) - test.AssertEquals(t, ids[0].ID, regA.ID) - test.AssertEquals(t, ids[1].ID, regB.ID) - test.AssertEquals(t, ids[2].ID, regC.ID) - test.AssertEquals(t, ids[3].ID, regD.ID) + for _, entry := range results { + switch entry.ID { + case regA.ID: + case regB.ID: + case regC.ID: + case regD.ID: + default: + t.Errorf("ID: %d not expected", entry.ID) + } + } +} + +func TestFindIDsWithExampleHostnames(t *testing.T) { + testCtx := setup(t) + defer testCtx.cleanUp() + + // Add some test registrations + testCtx.addRegistrations(t) + + // Run findIDsWithExampleHostnames - since no certificates have been + // added corresponding to the above registrations, no IDs should be + // found. + results, err := testCtx.c.findIDsWithExampleHostnames() + test.AssertNotError(t, err, "findIDs() produced error") + test.AssertEquals(t, len(results), 0) + + // Now add some certificates + testCtx.addCertificates(t) + + // Run findIDsWithExampleHostnames - since there are three + // registrations with unexpired certs we should get exactly three + // IDs back: RegA, RegC and RegD. RegB should *not* be present since + // their certificate has already expired. + results, err = testCtx.c.findIDsWithExampleHostnames() + test.AssertNotError(t, err, "findIDs() produced error") + test.AssertEquals(t, len(results), 3) + for _, entry := range results { + switch entry.ID { + case regA.ID: + test.AssertEquals(t, entry.Hostname, "example-a.com") + case regC.ID: + test.AssertEquals(t, entry.Hostname, "example-c.com") + case regD.ID: + test.AssertEquals(t, entry.Hostname, "example-d.com") + default: + t.Errorf("ID: %d not expected", entry.ID) + } + } + + // Allow a 1 year grace period + testCtx.c.grace = 360 * 24 * time.Hour + results, err = testCtx.c.findIDsWithExampleHostnames() + test.AssertNotError(t, err, "findIDs() produced error") + + // Now all four registrations should be returned, including RegB + // since it expired within the grace period + test.AssertEquals(t, len(results), 4) + for _, entry := range results { + switch entry.ID { + case regA.ID: + test.AssertEquals(t, entry.Hostname, "example-a.com") + case regB.ID: + test.AssertEquals(t, entry.Hostname, "example-b.com") + case regC.ID: + test.AssertEquals(t, entry.Hostname, "example-c.com") + case regD.ID: + test.AssertEquals(t, entry.Hostname, "example-d.com") + default: + t.Errorf("ID: %d not expected", entry.ID) + } + } } func TestFindIDsForDomains(t *testing.T) { @@ -90,49 +162,37 @@ func TestFindIDsForDomains(t *testing.T) { // Run findIDsForDomains - since no certificates have been added corresponding to // the above registrations, no IDs should be found. - ids, err := testCtx.c.findIDsForDomains([]string{"example-a.com", "example-b.com", "example-c.com", "example-d.com"}) + results, err := testCtx.c.findIDsForDomains([]string{"example-a.com", "example-b.com", "example-c.com", "example-d.com"}) test.AssertNotError(t, err, "findIDs() produced error") - test.AssertEquals(t, len(ids), 0) + test.AssertEquals(t, len(results), 0) // Now add some certificates testCtx.addCertificates(t) - ids, err = testCtx.c.findIDsForDomains([]string{"example-a.com", "example-b.com", "example-c.com", "example-d.com"}) + results, err = testCtx.c.findIDsForDomains([]string{"example-a.com", "example-b.com", "example-c.com", "example-d.com"}) test.AssertNotError(t, err, "findIDsForDomains() failed") - test.AssertEquals(t, len(ids), 3) - test.AssertEquals(t, ids[0].ID, regA.ID) - test.AssertEquals(t, ids[1].ID, regC.ID) - test.AssertEquals(t, ids[2].ID, regD.ID) -} - -func exampleIds() []id { - return []id{ - { - ID: 1, - }, - { - ID: 2, - }, - { - ID: 3, - }, + test.AssertEquals(t, len(results), 3) + for _, entry := range results { + switch entry.ID { + case regA.ID: + case regC.ID: + case regD.ID: + default: + t.Errorf("ID: %d not expected", entry.ID) + } } } -func TestWriteOutput(t *testing.T) { +func TestWriteToFile(t *testing.T) { expected := `[{"id":1},{"id":2},{"id":3}]` - - ids := exampleIds() + mockResults := idExporterResults{{ID: 1}, {ID: 2}, {ID: 3}} dir := os.TempDir() + f, err := ioutil.TempFile(dir, "ids_test") test.AssertNotError(t, err, "ioutil.TempFile produced an error") - // Writing the ids with no outFile should print to stdout - err = writeIDs(ids, "") - test.AssertNotError(t, err, "writeIDs with no outfile produced error") - - // Writing the ids to an outFile should produce the correct results - err = writeIDs(ids, f.Name()) + // Writing the result to an outFile should produce the correct results + err = mockResults.writeToFile(f.Name()) test.AssertNotError(t, err, fmt.Sprintf("writeIDs produced an error writing to %s", f.Name())) contents, err := ioutil.ReadFile(f.Name())