From 05e631593e271d95f4d6f62a6ff80f9d029191bb Mon Sep 17 00:00:00 2001 From: Samantha Frank Date: Mon, 7 Jul 2025 17:01:05 -0400 Subject: [PATCH] ratelimits: Supporting additions for admin tooling (#8279) - Export `ValidateLimit()` for use in the admin tool. - Add utility functions `DumpOverrides()` and `LoadOverridesByBucketKey()` to dump/load overrides to/from a YAML file. - Export `Limit` and several of its fields to support calls to `LoadOverridesByBucketKey()` and `ValidateLimit()`, and to return results from `DumpOverrides()`. - Add `BuildBucketKey()`, which builds and validates bucket keys based on the limit name and provided components. - Also add a `MarshalYAML()` method to `config.Duration`. Part of https://github.com/letsencrypt/boulder/issues/8165 --- config/duration.go | 5 + identifier/identifier.go | 18 ++ ratelimits/gcra.go | 6 +- ratelimits/gcra_test.go | 4 +- ratelimits/limit.go | 266 +++++++++++----- ratelimits/limit_test.go | 289 +++++++++++++++--- ratelimits/limiter.go | 28 +- ratelimits/limiter_test.go | 52 ++-- ratelimits/names.go | 110 ++++++- ratelimits/names_test.go | 202 ++++++++++++ .../testdata/busted_override_burst_0.yml | 2 +- ratelimits/transaction.go | 41 ++- ratelimits/transaction_test.go | 6 +- ratelimits/utilities.go | 137 ++++++--- 14 files changed, 928 insertions(+), 238 deletions(-) diff --git a/config/duration.go b/config/duration.go index 90cb2277d..beaf3f208 100644 --- a/config/duration.go +++ b/config/duration.go @@ -67,3 +67,8 @@ func (d *Duration) UnmarshalYAML(unmarshal func(interface{}) error) error { d.Duration = dur return nil } + +// MarshalYAML returns the string form of the duration, as a string. +func (d Duration) MarshalYAML() (any, error) { + return d.Duration.String(), nil +} diff --git a/identifier/identifier.go b/identifier/identifier.go index 4b161b862..9a6bb96bf 100644 --- a/identifier/identifier.go +++ b/identifier/identifier.go @@ -122,6 +122,24 @@ func NewIP(ip netip.Addr) ACMEIdentifier { } } +// FromString converts a string to an ACMEIdentifier. +func FromString(identStr string) ACMEIdentifier { + ip, err := netip.ParseAddr(identStr) + if err == nil { + return NewIP(ip) + } + return NewDNS(identStr) +} + +// FromStringSlice converts a slice of strings to a slice of ACMEIdentifier. +func FromStringSlice(identStrs []string) ACMEIdentifiers { + var idents ACMEIdentifiers + for _, identStr := range identStrs { + idents = append(idents, FromString(identStr)) + } + return idents +} + // fromX509 extracts the Subject Alternative Names from a certificate or CSR's fields, and // returns a slice of ACMEIdentifiers. func fromX509(commonName string, dnsNames []string, ipAddresses []net.IP) ACMEIdentifiers { diff --git a/ratelimits/gcra.go b/ratelimits/gcra.go index 24ae21859..5a6ff27b8 100644 --- a/ratelimits/gcra.go +++ b/ratelimits/gcra.go @@ -10,7 +10,7 @@ import ( // returns a Decision struct with the result of the decision and the updated // TAT. The cost must be 0 or greater and <= the burst capacity of the limit. func maybeSpend(clk clock.Clock, txn Transaction, tat time.Time) *Decision { - if txn.cost < 0 || txn.cost > txn.limit.burst { + if txn.cost < 0 || txn.cost > txn.limit.Burst { // The condition above is the union of the conditions checked in Check // and Spend methods of Limiter. If this panic is reached, it means that // the caller has introduced a bug. @@ -67,7 +67,7 @@ func maybeSpend(clk clock.Clock, txn Transaction, tat time.Time) *Decision { // or greater. A cost will only be refunded up to the burst capacity of the // limit. A partial refund is still considered successful. func maybeRefund(clk clock.Clock, txn Transaction, tat time.Time) *Decision { - if txn.cost < 0 || txn.cost > txn.limit.burst { + if txn.cost < 0 || txn.cost > txn.limit.Burst { // The condition above is checked in the Refund method of Limiter. If // this panic is reached, it means that the caller has introduced a bug. panic("invalid cost for maybeRefund") @@ -80,7 +80,7 @@ func maybeRefund(clk clock.Clock, txn Transaction, tat time.Time) *Decision { // The TAT is in the past, therefore the bucket is full. return &Decision{ allowed: false, - remaining: txn.limit.burst, + remaining: txn.limit.Burst, retryIn: time.Duration(0), resetIn: time.Duration(0), newTAT: tat, diff --git a/ratelimits/gcra_test.go b/ratelimits/gcra_test.go index 7f9fb2ca3..7b8cf2bc4 100644 --- a/ratelimits/gcra_test.go +++ b/ratelimits/gcra_test.go @@ -12,7 +12,7 @@ import ( func TestDecide(t *testing.T) { clk := clock.NewFake() - limit := &limit{burst: 10, count: 1, period: config.Duration{Duration: time.Second}} + limit := &Limit{Burst: 10, Count: 1, Period: config.Duration{Duration: time.Second}} limit.precompute() // Begin by using 1 of our 10 requests. @@ -139,7 +139,7 @@ func TestDecide(t *testing.T) { func TestMaybeRefund(t *testing.T) { clk := clock.NewFake() - limit := &limit{burst: 10, count: 1, period: config.Duration{Duration: time.Second}} + limit := &Limit{Burst: 10, Count: 1, Period: config.Duration{Duration: time.Second}} limit.precompute() // Begin by using 1 of our 10 requests. diff --git a/ratelimits/limit.go b/ratelimits/limit.go index 5919844e0..f87e6e3ee 100644 --- a/ratelimits/limit.go +++ b/ratelimits/limit.go @@ -1,10 +1,13 @@ package ratelimits import ( + "encoding/csv" "errors" "fmt" "net/netip" "os" + "sort" + "strconv" "strings" "github.com/letsencrypt/boulder/config" @@ -38,26 +41,32 @@ type LimitConfig struct { type LimitConfigs map[string]*LimitConfig -// limit defines the configuration for a rate limit or a rate limit override. +// Limit defines the configuration for a rate limit or a rate limit override. // -// The zero value of this struct is invalid, because some of the fields must -// be greater than zero. -type limit struct { - // burst specifies maximum concurrent allowed requests at any given time. It +// The zero value of this struct is invalid, because some of the fields must be +// greater than zero. It and several of its fields are exported to support admin +// tooling used during the migration from overrides.yaml to the overrides +// database table. +type Limit struct { + // Burst specifies maximum concurrent allowed requests at any given time. It // must be greater than zero. - burst int64 + Burst int64 - // count is the number of requests allowed per period. It must be greater + // Count is the number of requests allowed per period. It must be greater // than zero. - count int64 + Count int64 - // period is the duration of time in which the count (of requests) is + // Period is the duration of time in which the count (of requests) is // allowed. It must be greater than zero. - period config.Duration + Period config.Duration - // name is the name of the limit. It must be one of the Name enums defined + // Name is the name of the limit. It must be one of the Name enums defined // in this package. - name Name + Name Name + + // Comment is an optional field that can be used to provide additional + // context for an override. It is not used for default limits. + Comment string // emissionInterval is the interval, in nanoseconds, at which tokens are // added to a bucket (period / count). This is also the steady-state rate at @@ -76,25 +85,25 @@ type limit struct { } // precompute calculates the emissionInterval and burstOffset for the limit. -func (l *limit) precompute() { - l.emissionInterval = l.period.Nanoseconds() / l.count - l.burstOffset = l.emissionInterval * l.burst +func (l *Limit) precompute() { + l.emissionInterval = l.Period.Nanoseconds() / l.Count + l.burstOffset = l.emissionInterval * l.Burst } -func validateLimit(l *limit) error { - if l.burst <= 0 { - return fmt.Errorf("invalid burst '%d', must be > 0", l.burst) +func ValidateLimit(l *Limit) error { + if l.Burst <= 0 { + return fmt.Errorf("invalid burst '%d', must be > 0", l.Burst) } - if l.count <= 0 { - return fmt.Errorf("invalid count '%d', must be > 0", l.count) + if l.Count <= 0 { + return fmt.Errorf("invalid count '%d', must be > 0", l.Count) } - if l.period.Duration <= 0 { - return fmt.Errorf("invalid period '%s', must be > 0", l.period) + if l.Period.Duration <= 0 { + return fmt.Errorf("invalid period '%s', must be > 0", l.Period) } return nil } -type limits map[string]*limit +type Limits map[string]*Limit // loadDefaults marshals the defaults YAML file at path into a map of limits. func loadDefaults(path string) (LimitConfigs, error) { @@ -149,9 +158,9 @@ func parseOverrideNameId(key string) (Name, string, error) { return Unknown, "", fmt.Errorf("empty name in override %q, must be formatted 'name:id'", key) } - name, ok := stringToName[nameStr] + name, ok := StringToName[nameStr] if !ok { - return Unknown, "", fmt.Errorf("unrecognized name %q in override limit %q, must be one of %v", nameStr, key, limitNames) + return Unknown, "", fmt.Errorf("unrecognized name %q in override limit %q, must be one of %v", nameStr, key, LimitNames) } id := nameAndId[1] if id == "" { @@ -160,37 +169,52 @@ func parseOverrideNameId(key string) (Name, string, error) { return name, id, nil } +// parseOverrideNameEnumId is like parseOverrideNameId, but it expects the +// key to be formatted as 'name:id', where 'name' is a Name enum string and 'id' +// is a string identifier. It returns an error if either part is missing or invalid. +func parseOverrideNameEnumId(key string) (Name, string, error) { + if !strings.Contains(key, ":") { + // Avoids a potential panic in strings.SplitN below. + return Unknown, "", fmt.Errorf("invalid override %q, must be formatted 'name:id'", key) + } + nameStrAndId := strings.SplitN(key, ":", 2) + if len(nameStrAndId) != 2 { + return Unknown, "", fmt.Errorf("invalid override %q, must be formatted 'name:id'", key) + } + + nameInt, err := strconv.Atoi(nameStrAndId[0]) + if err != nil { + return Unknown, "", fmt.Errorf("invalid name %q in override limit %q, must be an integer", nameStrAndId[0], key) + } + name := Name(nameInt) + if !name.isValid() { + return Unknown, "", fmt.Errorf("invalid name %q in override limit %q, must be one of %v", nameStrAndId[0], key, LimitNames) + + } + id := nameStrAndId[1] + if id == "" { + return Unknown, "", fmt.Errorf("empty id in override %q, must be formatted 'name:id'", key) + } + return name, id, nil +} + // parseOverrideLimits validates a YAML list of override limits. It must be // formatted as a list of maps, where each map has a single key representing the // limit name and a value that is a map containing the limit fields and an // additional 'ids' field that is a list of ids that this override applies to. -func parseOverrideLimits(newOverridesYAML overridesYAML) (limits, error) { - parsed := make(limits) +func parseOverrideLimits(newOverridesYAML overridesYAML) (Limits, error) { + parsed := make(Limits) for _, ov := range newOverridesYAML { for k, v := range ov { - name, ok := stringToName[k] + name, ok := StringToName[k] if !ok { - return nil, fmt.Errorf("unrecognized name %q in override limit, must be one of %v", k, limitNames) - } - - lim := &limit{ - burst: v.Burst, - count: v.Count, - period: v.Period, - name: name, - isOverride: true, - } - lim.precompute() - - err := validateLimit(lim) - if err != nil { - return nil, fmt.Errorf("validating override limit %q: %w", k, err) + return nil, fmt.Errorf("unrecognized name %q in override limit, must be one of %v", k, LimitNames) } for _, entry := range v.Ids { id := entry.Id - err = validateIdForName(name, id) + err := validateIdForName(name, id) if err != nil { return nil, fmt.Errorf( "validating name %s and id %q for override limit %q: %w", name, id, k, err) @@ -204,7 +228,7 @@ func parseOverrideLimits(newOverridesYAML overridesYAML) (limits, error) { // (IPv6) prefixes in CIDR notation. ip, err := netip.ParseAddr(id) if err == nil { - prefix, err := coveringPrefix(ip) + prefix, err := coveringIPPrefix(name, ip) if err != nil { return nil, fmt.Errorf( "computing prefix for IP address %q: %w", id, err) @@ -214,16 +238,22 @@ func parseOverrideLimits(newOverridesYAML overridesYAML) (limits, error) { case CertificatesPerFQDNSet: // Compute the hash of a comma-separated list of identifier // values. - var idents identifier.ACMEIdentifiers - for _, value := range strings.Split(id, ",") { - ip, err := netip.ParseAddr(value) - if err == nil { - idents = append(idents, identifier.NewIP(ip)) - } else { - idents = append(idents, identifier.NewDNS(value)) - } - } - id = fmt.Sprintf("%x", core.HashIdentifiers(idents)) + id = fmt.Sprintf("%x", core.HashIdentifiers(identifier.FromStringSlice(strings.Split(id, ",")))) + } + + lim := &Limit{ + Burst: v.Burst, + Count: v.Count, + Period: v.Period, + Name: name, + Comment: entry.Comment, + isOverride: true, + } + lim.precompute() + + err = ValidateLimit(lim) + if err != nil { + return nil, fmt.Errorf("validating override limit %q: %w", k, err) } parsed[joinWithColon(name.EnumString(), id)] = lim @@ -234,23 +264,23 @@ func parseOverrideLimits(newOverridesYAML overridesYAML) (limits, error) { } // parseDefaultLimits validates a map of default limits and rekeys it by 'Name'. -func parseDefaultLimits(newDefaultLimits LimitConfigs) (limits, error) { - parsed := make(limits) +func parseDefaultLimits(newDefaultLimits LimitConfigs) (Limits, error) { + parsed := make(Limits) for k, v := range newDefaultLimits { - name, ok := stringToName[k] + name, ok := StringToName[k] if !ok { - return nil, fmt.Errorf("unrecognized name %q in default limit, must be one of %v", k, limitNames) + return nil, fmt.Errorf("unrecognized name %q in default limit, must be one of %v", k, LimitNames) } - lim := &limit{ - burst: v.Burst, - count: v.Count, - period: v.Period, - name: name, + lim := &Limit{ + Burst: v.Burst, + Count: v.Count, + Period: v.Period, + Name: name, } - err := validateLimit(lim) + err := ValidateLimit(lim) if err != nil { return nil, fmt.Errorf("parsing default limit %q: %w", k, err) } @@ -263,10 +293,10 @@ func parseDefaultLimits(newDefaultLimits LimitConfigs) (limits, error) { type limitRegistry struct { // defaults stores default limits by 'name'. - defaults limits + defaults Limits // overrides stores override limits by 'name:id'. - overrides limits + overrides Limits } func newLimitRegistryFromFiles(defaults, overrides string) (*limitRegistry, error) { @@ -308,7 +338,7 @@ func newLimitRegistry(defaults LimitConfigs, overrides overridesYAML) (*limitReg // required, bucketKey is optional. If bucketkey is empty, the default for the // limit specified by name is returned. If no default limit exists for the // specified name, errLimitDisabled is returned. -func (l *limitRegistry) getLimit(name Name, bucketKey string) (*limit, error) { +func (l *limitRegistry) getLimit(name Name, bucketKey string) (*Limit, error) { if !name.isValid() { // This should never happen. Callers should only be specifying the limit // Name enums defined in this package. @@ -327,3 +357,103 @@ func (l *limitRegistry) getLimit(name Name, bucketKey string) (*limit, error) { } return nil, errLimitDisabled } + +// LoadOverridesByBucketKey loads the overrides YAML at the supplied path, +// parses it with the existing helpers, and returns the resulting limits map +// keyed by ":". This function is exported to support admin tooling +// used during the migration from overrides.yaml to the overrides database +// table. +func LoadOverridesByBucketKey(path string) (Limits, error) { + ovs, err := loadOverrides(path) + if err != nil { + return nil, err + } + return parseOverrideLimits(ovs) +} + +// DumpOverrides writes the provided overrides to CSV at the supplied path. Each +// override is written as a single row, one per ID. Rows are sorted in the +// following order: +// - Name (ascending) +// - Count (descending) +// - Burst (descending) +// - Period (ascending) +// - Comment (ascending) +// - ID (ascending) +// +// This function supports admin tooling that routinely exports the overrides +// table for investigation or auditing. +func DumpOverrides(path string, overrides Limits) error { + type row struct { + name string + id string + count int64 + burst int64 + period string + comment string + } + + var rows []row + for bucketKey, limit := range overrides { + name, id, err := parseOverrideNameEnumId(bucketKey) + if err != nil { + return err + } + + rows = append(rows, row{ + name: name.String(), + id: id, + count: limit.Count, + burst: limit.Burst, + period: limit.Period.Duration.String(), + comment: limit.Comment, + }) + } + + sort.Slice(rows, func(i, j int) bool { + // Sort by limit name in ascending order. + if rows[i].name != rows[j].name { + return rows[i].name < rows[j].name + } + // Sort by count in descending order (higher counts first). + if rows[i].count != rows[j].count { + return rows[i].count > rows[j].count + } + // Sort by burst in descending order (higher bursts first). + if rows[i].burst != rows[j].burst { + return rows[i].burst > rows[j].burst + } + // Sort by period in ascending order (shorter durations first). + if rows[i].period != rows[j].period { + return rows[i].period < rows[j].period + } + // Sort by comment in ascending order. + if rows[i].comment != rows[j].comment { + return rows[i].comment < rows[j].comment + } + // Sort by ID in ascending order. + return rows[i].id < rows[j].id + }) + + f, err := os.Create(path) + if err != nil { + return err + } + defer f.Close() + + w := csv.NewWriter(f) + err = w.Write([]string{"name", "id", "count", "burst", "period", "comment"}) + if err != nil { + return err + } + + for _, r := range rows { + err := w.Write([]string{r.name, r.id, strconv.FormatInt(r.count, 10), strconv.FormatInt(r.burst, 10), r.period, r.comment}) + if err != nil { + return err + } + } + w.Flush() + + return w.Error() +} diff --git a/ratelimits/limit_test.go b/ratelimits/limit_test.go index 593c811aa..fb20ff08f 100644 --- a/ratelimits/limit_test.go +++ b/ratelimits/limit_test.go @@ -3,6 +3,8 @@ package ratelimits import ( "net/netip" "os" + "path/filepath" + "strings" "testing" "time" @@ -15,7 +17,7 @@ import ( // parseDefaultLimits to handle a YAML file. // // TODO(#7901): Update the tests to test these functions individually. -func loadAndParseDefaultLimits(path string) (limits, error) { +func loadAndParseDefaultLimits(path string) (Limits, error) { fromFile, err := loadDefaults(path) if err != nil { return nil, err @@ -28,7 +30,7 @@ func loadAndParseDefaultLimits(path string) (limits, error) { // parseOverrideLimits to handle a YAML file. // // TODO(#7901): Update the tests to test these functions individually. -func loadAndParseOverrideLimits(path string) (limits, error) { +func loadAndParseOverrideLimits(path string) (Limits, error) { fromFile, err := loadOverrides(path) if err != nil { return nil, err @@ -69,17 +71,79 @@ func TestParseOverrideNameId(t *testing.T) { test.AssertError(t, err, "invalid enum") } +func TestParseOverrideNameEnumId(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input string + wantLimit Name + wantId string + expectError bool + }{ + { + name: "valid IPv4 address", + input: NewRegistrationsPerIPAddress.EnumString() + ":10.0.0.1", + wantLimit: NewRegistrationsPerIPAddress, + wantId: "10.0.0.1", + expectError: false, + }, + { + name: "valid IPv6 address range", + input: NewRegistrationsPerIPv6Range.EnumString() + ":2001:0db8:0000::/48", + wantLimit: NewRegistrationsPerIPv6Range, + wantId: "2001:0db8:0000::/48", + expectError: false, + }, + { + name: "missing colon", + input: NewRegistrationsPerIPAddress.EnumString() + "10.0.0.1", + expectError: true, + }, + { + name: "empty string", + input: "", + expectError: true, + }, + { + name: "only a colon", + input: NewRegistrationsPerIPAddress.EnumString() + ":", + expectError: true, + }, + { + name: "invalid enum", + input: "lol:noexist", + expectError: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + limit, id, err := parseOverrideNameEnumId(tc.input) + if tc.expectError { + if err == nil { + t.Errorf("expected error for input %q, but got none", tc.input) + } + } else { + test.AssertNotError(t, err, tc.name) + test.AssertEquals(t, limit, tc.wantLimit) + test.AssertEquals(t, id, tc.wantId) + } + }) + } +} + func TestValidateLimit(t *testing.T) { - err := validateLimit(&limit{burst: 1, count: 1, period: config.Duration{Duration: time.Second}}) + err := ValidateLimit(&Limit{Burst: 1, Count: 1, Period: config.Duration{Duration: time.Second}}) test.AssertNotError(t, err, "valid limit") // All of the following are invalid. - for _, l := range []*limit{ - {burst: 0, count: 1, period: config.Duration{Duration: time.Second}}, - {burst: 1, count: 0, period: config.Duration{Duration: time.Second}}, - {burst: 1, count: 1, period: config.Duration{Duration: 0}}, + for _, l := range []*Limit{ + {Burst: 0, Count: 1, Period: config.Duration{Duration: time.Second}}, + {Burst: 1, Count: 0, Period: config.Duration{Duration: time.Second}}, + {Burst: 1, Count: 1, Period: config.Duration{Duration: 0}}, } { - err = validateLimit(l) + err = ValidateLimit(l) test.AssertError(t, err, "limit should be invalid") } } @@ -89,29 +153,29 @@ func TestLoadAndParseOverrideLimits(t *testing.T) { l, err := loadAndParseOverrideLimits("testdata/working_override.yml") test.AssertNotError(t, err, "valid single override limit") expectKey := joinWithColon(NewRegistrationsPerIPAddress.EnumString(), "64.112.117.1") - test.AssertEquals(t, l[expectKey].burst, int64(40)) - test.AssertEquals(t, l[expectKey].count, int64(40)) - test.AssertEquals(t, l[expectKey].period.Duration, time.Second) + test.AssertEquals(t, l[expectKey].Burst, int64(40)) + test.AssertEquals(t, l[expectKey].Count, int64(40)) + test.AssertEquals(t, l[expectKey].Period.Duration, time.Second) // Load single valid override limit with a 'domainOrCIDR' Id. l, err = loadAndParseOverrideLimits("testdata/working_override_regid_domainorcidr.yml") test.AssertNotError(t, err, "valid single override limit with Id of regId:domainOrCIDR") expectKey = joinWithColon(CertificatesPerDomain.EnumString(), "example.com") - test.AssertEquals(t, l[expectKey].burst, int64(40)) - test.AssertEquals(t, l[expectKey].count, int64(40)) - test.AssertEquals(t, l[expectKey].period.Duration, time.Second) + test.AssertEquals(t, l[expectKey].Burst, int64(40)) + test.AssertEquals(t, l[expectKey].Count, int64(40)) + test.AssertEquals(t, l[expectKey].Period.Duration, time.Second) // Load multiple valid override limits with 'regId' Ids. l, err = loadAndParseOverrideLimits("testdata/working_overrides.yml") test.AssertNotError(t, err, "multiple valid override limits") expectKey1 := joinWithColon(NewRegistrationsPerIPAddress.EnumString(), "64.112.117.1") - test.AssertEquals(t, l[expectKey1].burst, int64(40)) - test.AssertEquals(t, l[expectKey1].count, int64(40)) - test.AssertEquals(t, l[expectKey1].period.Duration, time.Second) + test.AssertEquals(t, l[expectKey1].Burst, int64(40)) + test.AssertEquals(t, l[expectKey1].Count, int64(40)) + test.AssertEquals(t, l[expectKey1].Period.Duration, time.Second) expectKey2 := joinWithColon(NewRegistrationsPerIPv6Range.EnumString(), "2602:80a:6000::/48") - test.AssertEquals(t, l[expectKey2].burst, int64(50)) - test.AssertEquals(t, l[expectKey2].count, int64(50)) - test.AssertEquals(t, l[expectKey2].period.Duration, time.Second*2) + test.AssertEquals(t, l[expectKey2].Burst, int64(50)) + test.AssertEquals(t, l[expectKey2].Count, int64(50)) + test.AssertEquals(t, l[expectKey2].Period.Duration, time.Second*2) // Load multiple valid override limits with 'fqdnSet' Ids, as follows: // - CertificatesPerFQDNSet:example.com @@ -128,18 +192,18 @@ func TestLoadAndParseOverrideLimits(t *testing.T) { l, err = loadAndParseOverrideLimits("testdata/working_overrides_regid_fqdnset.yml") test.AssertNotError(t, err, "multiple valid override limits with 'fqdnSet' Ids") - test.AssertEquals(t, l[entryKey1].burst, int64(40)) - test.AssertEquals(t, l[entryKey1].count, int64(40)) - test.AssertEquals(t, l[entryKey1].period.Duration, time.Second) - test.AssertEquals(t, l[entryKey2].burst, int64(50)) - test.AssertEquals(t, l[entryKey2].count, int64(50)) - test.AssertEquals(t, l[entryKey2].period.Duration, time.Second*2) - test.AssertEquals(t, l[entryKey3].burst, int64(60)) - test.AssertEquals(t, l[entryKey3].count, int64(60)) - test.AssertEquals(t, l[entryKey3].period.Duration, time.Second*3) - test.AssertEquals(t, l[entryKey4].burst, int64(60)) - test.AssertEquals(t, l[entryKey4].count, int64(60)) - test.AssertEquals(t, l[entryKey4].period.Duration, time.Second*4) + test.AssertEquals(t, l[entryKey1].Burst, int64(40)) + test.AssertEquals(t, l[entryKey1].Count, int64(40)) + test.AssertEquals(t, l[entryKey1].Period.Duration, time.Second) + test.AssertEquals(t, l[entryKey2].Burst, int64(50)) + test.AssertEquals(t, l[entryKey2].Count, int64(50)) + test.AssertEquals(t, l[entryKey2].Period.Duration, time.Second*2) + test.AssertEquals(t, l[entryKey3].Burst, int64(60)) + test.AssertEquals(t, l[entryKey3].Count, int64(60)) + test.AssertEquals(t, l[entryKey3].Period.Duration, time.Second*3) + test.AssertEquals(t, l[entryKey4].Burst, int64(60)) + test.AssertEquals(t, l[entryKey4].Count, int64(60)) + test.AssertEquals(t, l[entryKey4].Period.Duration, time.Second*4) // Path is empty string. _, err = loadAndParseOverrideLimits("") @@ -186,19 +250,19 @@ func TestLoadAndParseDefaultLimits(t *testing.T) { // Load a single valid default limit. l, err := loadAndParseDefaultLimits("testdata/working_default.yml") test.AssertNotError(t, err, "valid single default limit") - test.AssertEquals(t, l[NewRegistrationsPerIPAddress.EnumString()].burst, int64(20)) - test.AssertEquals(t, l[NewRegistrationsPerIPAddress.EnumString()].count, int64(20)) - test.AssertEquals(t, l[NewRegistrationsPerIPAddress.EnumString()].period.Duration, time.Second) + test.AssertEquals(t, l[NewRegistrationsPerIPAddress.EnumString()].Burst, int64(20)) + test.AssertEquals(t, l[NewRegistrationsPerIPAddress.EnumString()].Count, int64(20)) + test.AssertEquals(t, l[NewRegistrationsPerIPAddress.EnumString()].Period.Duration, time.Second) // Load multiple valid default limits. l, err = loadAndParseDefaultLimits("testdata/working_defaults.yml") test.AssertNotError(t, err, "multiple valid default limits") - test.AssertEquals(t, l[NewRegistrationsPerIPAddress.EnumString()].burst, int64(20)) - test.AssertEquals(t, l[NewRegistrationsPerIPAddress.EnumString()].count, int64(20)) - test.AssertEquals(t, l[NewRegistrationsPerIPAddress.EnumString()].period.Duration, time.Second) - test.AssertEquals(t, l[NewRegistrationsPerIPv6Range.EnumString()].burst, int64(30)) - test.AssertEquals(t, l[NewRegistrationsPerIPv6Range.EnumString()].count, int64(30)) - test.AssertEquals(t, l[NewRegistrationsPerIPv6Range.EnumString()].period.Duration, time.Second*2) + test.AssertEquals(t, l[NewRegistrationsPerIPAddress.EnumString()].Burst, int64(20)) + test.AssertEquals(t, l[NewRegistrationsPerIPAddress.EnumString()].Count, int64(20)) + test.AssertEquals(t, l[NewRegistrationsPerIPAddress.EnumString()].Period.Duration, time.Second) + test.AssertEquals(t, l[NewRegistrationsPerIPv6Range.EnumString()].Burst, int64(30)) + test.AssertEquals(t, l[NewRegistrationsPerIPv6Range.EnumString()].Count, int64(30)) + test.AssertEquals(t, l[NewRegistrationsPerIPv6Range.EnumString()].Period.Duration, time.Second*2) // Path is empty string. _, err = loadAndParseDefaultLimits("") @@ -230,3 +294,146 @@ func TestLoadAndParseDefaultLimits(t *testing.T) { test.AssertError(t, err, "multiple default limits, one is bad") test.Assert(t, !os.IsNotExist(err), "test file should exist") } + +func TestLoadAndDumpOverrides(t *testing.T) { + t.Parallel() + + input := ` +- CertificatesPerDomain: + burst: 5000 + count: 5000 + period: 168h0m0s + ids: + - id: example.com + comment: IN-10057 + - id: example.net + comment: IN-10057 +- CertificatesPerDomain: + burst: 300 + count: 300 + period: 168h0m0s + ids: + - id: example.org + comment: IN-10057 +- CertificatesPerDomainPerAccount: + burst: 12000 + count: 12000 + period: 168h0m0s + ids: + - id: "123456789" + comment: Affluent (IN-8322) +- CertificatesPerDomainPerAccount: + burst: 6000 + count: 6000 + period: 168h0m0s + ids: + - id: "543219876" + comment: Affluent (IN-8322) + - id: "987654321" + comment: Affluent (IN-8322) +- CertificatesPerFQDNSet: + burst: 50 + count: 50 + period: 168h0m0s + ids: + - id: example.co.uk,example.cn + comment: IN-6843 +- CertificatesPerFQDNSet: + burst: 24 + count: 24 + period: 168h0m0s + ids: + - id: example.org,example.com,example.net + comment: IN-6006 +- FailedAuthorizationsPerDomainPerAccount: + burst: 250 + count: 250 + period: 1h0m0s + ids: + - id: "123456789" + comment: Digital Lake (IN-6736) +- FailedAuthorizationsPerDomainPerAccount: + burst: 50 + count: 50 + period: 1h0m0s + ids: + - id: "987654321" + comment: Digital Lake (IN-6856) +- FailedAuthorizationsPerDomainPerAccount: + burst: 10 + count: 10 + period: 1h0m0s + ids: + - id: "543219876" + comment: Big Mart (IN-6949) +- NewOrdersPerAccount: + burst: 3000 + count: 3000 + period: 3h0m0s + ids: + - id: "123456789" + comment: Galaxy Hoster (IN-8180) +- NewOrdersPerAccount: + burst: 1000 + count: 1000 + period: 3h0m0s + ids: + - id: "543219876" + comment: Big Mart (IN-8180) + - id: "987654321" + comment: Buy More (IN-10057) +- NewRegistrationsPerIPAddress: + burst: 100000 + count: 100000 + period: 3h0m0s + ids: + - id: 2600:1f1c:5e0:e702:ca06:d2a3:c7ce:a02e + comment: example.org IN-2395 + - id: 55.66.77.88 + comment: example.org IN-2395 +- NewRegistrationsPerIPAddress: + burst: 200 + count: 200 + period: 3h0m0s + ids: + - id: 11.22.33.44 + comment: example.net (IN-1583)` + + expectCSV := ` +name,id,count,burst,period,comment +CertificatesPerDomain,example.com,5000,5000,168h0m0s,IN-10057 +CertificatesPerDomain,example.net,5000,5000,168h0m0s,IN-10057 +CertificatesPerDomain,example.org,300,300,168h0m0s,IN-10057 +CertificatesPerDomainPerAccount,123456789,12000,12000,168h0m0s,Affluent (IN-8322) +CertificatesPerDomainPerAccount,543219876,6000,6000,168h0m0s,Affluent (IN-8322) +CertificatesPerDomainPerAccount,987654321,6000,6000,168h0m0s,Affluent (IN-8322) +CertificatesPerFQDNSet,7c956936126b492845ddb48f4d220034509e7c0ad54ed2c1ba2650406846d9c3,50,50,168h0m0s,IN-6843 +CertificatesPerFQDNSet,394e82811f52e2da38b970afdb21c9bc9af81060939c690183c00fce37408738,24,24,168h0m0s,IN-6006 +FailedAuthorizationsPerDomainPerAccount,123456789,250,250,1h0m0s,Digital Lake (IN-6736) +FailedAuthorizationsPerDomainPerAccount,987654321,50,50,1h0m0s,Digital Lake (IN-6856) +FailedAuthorizationsPerDomainPerAccount,543219876,10,10,1h0m0s,Big Mart (IN-6949) +NewOrdersPerAccount,123456789,3000,3000,3h0m0s,Galaxy Hoster (IN-8180) +NewOrdersPerAccount,543219876,1000,1000,3h0m0s,Big Mart (IN-8180) +NewOrdersPerAccount,987654321,1000,1000,3h0m0s,Buy More (IN-10057) +NewRegistrationsPerIPAddress,2600:1f1c:5e0:e702:ca06:d2a3:c7ce:a02e,100000,100000,3h0m0s,example.org IN-2395 +NewRegistrationsPerIPAddress,55.66.77.88,100000,100000,3h0m0s,example.org IN-2395 +NewRegistrationsPerIPAddress,11.22.33.44,200,200,3h0m0s,example.net (IN-1583) +` + tempDir := t.TempDir() + tempFile := filepath.Join(tempDir, "overrides.yaml") + + err := os.WriteFile(tempFile, []byte(input), 0644) + test.AssertNotError(t, err, "writing temp overrides.yaml") + + original, err := LoadOverridesByBucketKey(tempFile) + test.AssertNotError(t, err, "loading overrides") + test.Assert(t, len(original) > 0, "expected at least one override loaded") + + dumpFile := filepath.Join(tempDir, "dumped.yaml") + err = DumpOverrides(dumpFile, original) + test.AssertNotError(t, err, "dumping overrides") + + dumped, err := os.ReadFile(dumpFile) + test.AssertNotError(t, err, "reading dumped overrides file") + test.AssertEquals(t, strings.TrimLeft(string(dumped), "\n"), strings.TrimLeft(expectCSV, "\n")) +} diff --git a/ratelimits/limiter.go b/ratelimits/limiter.go index b7a195028..a04a592a2 100644 --- a/ratelimits/limiter.go +++ b/ratelimits/limiter.go @@ -104,13 +104,13 @@ func (d *Decision) Result(now time.Time) error { // There is no case for FailedAuthorizationsForPausingPerDomainPerAccount // because the RA will pause clients who exceed that ratelimit. - switch d.transaction.limit.name { + switch d.transaction.limit.Name { case NewRegistrationsPerIPAddress: return berrors.RegistrationsPerIPAddressError( retryAfter, "too many new registrations (%d) from this IP address in the last %s, retry after %s", - d.transaction.limit.burst, - d.transaction.limit.period.Duration, + d.transaction.limit.Burst, + d.transaction.limit.Period.Duration, retryAfterTs, ) @@ -118,16 +118,16 @@ func (d *Decision) Result(now time.Time) error { return berrors.RegistrationsPerIPv6RangeError( retryAfter, "too many new registrations (%d) from this /48 subnet of IPv6 addresses in the last %s, retry after %s", - d.transaction.limit.burst, - d.transaction.limit.period.Duration, + d.transaction.limit.Burst, + d.transaction.limit.Period.Duration, retryAfterTs, ) case NewOrdersPerAccount: return berrors.NewOrdersPerAccountError( retryAfter, "too many new orders (%d) from this account in the last %s, retry after %s", - d.transaction.limit.burst, - d.transaction.limit.period.Duration, + d.transaction.limit.Burst, + d.transaction.limit.Period.Duration, retryAfterTs, ) @@ -141,9 +141,9 @@ func (d *Decision) Result(now time.Time) error { return berrors.FailedAuthorizationsPerDomainPerAccountError( retryAfter, "too many failed authorizations (%d) for %q in the last %s, retry after %s", - d.transaction.limit.burst, + d.transaction.limit.Burst, identValue, - d.transaction.limit.period.Duration, + d.transaction.limit.Period.Duration, retryAfterTs, ) @@ -157,9 +157,9 @@ func (d *Decision) Result(now time.Time) error { return berrors.CertificatesPerDomainError( retryAfter, "too many certificates (%d) already issued for %q in the last %s, retry after %s", - d.transaction.limit.burst, + d.transaction.limit.Burst, domainOrCIDR, - d.transaction.limit.period.Duration, + d.transaction.limit.Period.Duration, retryAfterTs, ) @@ -167,8 +167,8 @@ func (d *Decision) Result(now time.Time) error { return berrors.CertificatesPerFQDNSetError( retryAfter, "too many certificates (%d) already issued for this exact set of identifiers in the last %s, retry after %s", - d.transaction.limit.burst, - d.transaction.limit.period.Duration, + d.transaction.limit.Burst, + d.transaction.limit.Period.Duration, retryAfterTs, ) @@ -346,7 +346,7 @@ func (l *Limiter) BatchSpend(ctx context.Context, txns []Transaction) (*Decision totalLatency := l.clk.Since(start) perTxnLatency := totalLatency / time.Duration(len(txnOutcomes)) for txn, outcome := range txnOutcomes { - l.spendLatency.WithLabelValues(txn.limit.name.String(), outcome).Observe(perTxnLatency.Seconds()) + l.spendLatency.WithLabelValues(txn.limit.Name.String(), outcome).Observe(perTxnLatency.Seconds()) } return batchDecision, nil } diff --git a/ratelimits/limiter_test.go b/ratelimits/limiter_test.go index eb6f938b6..7aabd8593 100644 --- a/ratelimits/limiter_test.go +++ b/ratelimits/limiter_test.go @@ -464,10 +464,10 @@ func TestRateLimitError(t *testing.T) { allowed: false, retryIn: 5 * time.Second, transaction: Transaction{ - limit: &limit{ - name: NewRegistrationsPerIPAddress, - burst: 10, - period: config.Duration{Duration: time.Hour}, + limit: &Limit{ + Name: NewRegistrationsPerIPAddress, + Burst: 10, + Period: config.Duration{Duration: time.Hour}, }, }, }, @@ -480,10 +480,10 @@ func TestRateLimitError(t *testing.T) { allowed: false, retryIn: 10 * time.Second, transaction: Transaction{ - limit: &limit{ - name: NewRegistrationsPerIPv6Range, - burst: 5, - period: config.Duration{Duration: time.Hour}, + limit: &Limit{ + Name: NewRegistrationsPerIPv6Range, + Burst: 5, + Period: config.Duration{Duration: time.Hour}, }, }, }, @@ -496,10 +496,10 @@ func TestRateLimitError(t *testing.T) { allowed: false, retryIn: 10 * time.Second, transaction: Transaction{ - limit: &limit{ - name: NewOrdersPerAccount, - burst: 2, - period: config.Duration{Duration: time.Hour}, + limit: &Limit{ + Name: NewOrdersPerAccount, + Burst: 2, + Period: config.Duration{Duration: time.Hour}, }, }, }, @@ -512,10 +512,10 @@ func TestRateLimitError(t *testing.T) { allowed: false, retryIn: 15 * time.Second, transaction: Transaction{ - limit: &limit{ - name: FailedAuthorizationsPerDomainPerAccount, - burst: 7, - period: config.Duration{Duration: time.Hour}, + limit: &Limit{ + Name: FailedAuthorizationsPerDomainPerAccount, + Burst: 7, + Period: config.Duration{Duration: time.Hour}, }, bucketKey: "4:12345:example.com", }, @@ -529,10 +529,10 @@ func TestRateLimitError(t *testing.T) { allowed: false, retryIn: 20 * time.Second, transaction: Transaction{ - limit: &limit{ - name: CertificatesPerDomain, - burst: 3, - period: config.Duration{Duration: time.Hour}, + limit: &Limit{ + Name: CertificatesPerDomain, + Burst: 3, + Period: config.Duration{Duration: time.Hour}, }, bucketKey: "5:example.org", }, @@ -546,10 +546,10 @@ func TestRateLimitError(t *testing.T) { allowed: false, retryIn: 20 * time.Second, transaction: Transaction{ - limit: &limit{ - name: CertificatesPerDomainPerAccount, - burst: 3, - period: config.Duration{Duration: time.Hour}, + limit: &Limit{ + Name: CertificatesPerDomainPerAccount, + Burst: 3, + Period: config.Duration{Duration: time.Hour}, }, bucketKey: "6:12345678:example.net", }, @@ -563,8 +563,8 @@ func TestRateLimitError(t *testing.T) { allowed: false, retryIn: 30 * time.Second, transaction: Transaction{ - limit: &limit{ - name: 9999999, + limit: &Limit{ + Name: 9999999, }, }, }, diff --git a/ratelimits/names.go b/ratelimits/names.go index e23c03c6d..1ce3c514c 100644 --- a/ratelimits/names.go +++ b/ratelimits/names.go @@ -7,6 +7,7 @@ import ( "strings" "github.com/letsencrypt/boulder/iana" + "github.com/letsencrypt/boulder/identifier" "github.com/letsencrypt/boulder/policy" ) @@ -17,8 +18,9 @@ import ( // IMPORTANT: If you add or remove a limit Name, you MUST update: // - the string representation of the Name in nameToString, // - the validators for that name in validateIdForName(), -// - the transaction constructors for that name in bucket.go, and -// - the Subscriber facing error message in ErrForDecision(). +// - the transaction constructors for that name in bucket.go +// - the Subscriber facing error message in ErrForDecision(), and +// - the case in BuildBucketKey() for that name. type Name int const ( @@ -206,7 +208,7 @@ func validateRegIdIdentValue(id string) error { // validateDomainOrCIDR validates that the provided string is either a domain // name or an IP address. IPv6 addresses must be the lowest address in their // /64, i.e. their last 64 bits must be zero. -func validateDomainOrCIDR(id string) error { +func validateDomainOrCIDR(limit Name, id string) error { domainErr := policy.ValidDomain(id) if domainErr == nil { // This is a valid domain. @@ -222,14 +224,13 @@ func validateDomainOrCIDR(id string) error { return fmt.Errorf("invalid IP address %q, must be in canonical form (%q)", id, ip.String()) } - prefix, prefixErr := coveringPrefix(ip) + prefix, prefixErr := coveringIPPrefix(limit, ip) if prefixErr != nil { return fmt.Errorf("invalid IP address %q, couldn't determine prefix: %w", id, prefixErr) } if prefix.Addr() != ip { return fmt.Errorf("invalid IP address %q, must be the lowest address in its prefix (%q)", id, prefix.Addr().String()) } - return iana.IsReservedPrefix(prefix) } @@ -237,7 +238,7 @@ func validateDomainOrCIDR(id string) error { // 'regId:domainOrCIDR', where domainOrCIDR is either a domain name or an IP // address. IPv6 addresses must be the lowest address in their /64, i.e. their // last 64 bits must be zero. -func validateRegIdDomainOrCIDR(id string) error { +func validateRegIdDomainOrCIDR(limit Name, id string) error { regIdDomainOrCIDR := strings.Split(id, ":") if len(regIdDomainOrCIDR) != 2 { return fmt.Errorf( @@ -248,7 +249,7 @@ func validateRegIdDomainOrCIDR(id string) error { return fmt.Errorf( "invalid regId, %q must be formatted 'regId:domainOrCIDR'", id) } - err = validateDomainOrCIDR(regIdDomainOrCIDR[1]) + err = validateDomainOrCIDR(limit, regIdDomainOrCIDR[1]) if err != nil { return fmt.Errorf("invalid domainOrCIDR, %q must be formatted 'regId:domainOrCIDR': %w", id, err) } @@ -301,7 +302,7 @@ func validateIdForName(name Name, id string) error { case CertificatesPerDomainPerAccount: if strings.Contains(id, ":") { // 'enum:regId:domainOrCIDR' for transaction - return validateRegIdDomainOrCIDR(id) + return validateRegIdDomainOrCIDR(name, id) } else { // 'enum:regId' for overrides return validateRegId(id) @@ -309,7 +310,7 @@ func validateIdForName(name Name, id string) error { case CertificatesPerDomain: // 'enum:domainOrCIDR' - return validateDomainOrCIDR(id) + return validateDomainOrCIDR(name, id) case CertificatesPerFQDNSet: // 'enum:fqdnSet' @@ -333,8 +334,8 @@ func validateIdForName(name Name, id string) error { } } -// stringToName is a map of string names to Name values. -var stringToName = func() map[string]Name { +// StringToName is a map of string names to Name values. +var StringToName = func() map[string]Name { m := make(map[string]Name, len(nameToString)) for k, v := range nameToString { m[v] = k @@ -342,11 +343,94 @@ var stringToName = func() map[string]Name { return m }() -// limitNames is a slice of all rate limit names. -var limitNames = func() []string { +// LimitNames is a slice of all rate limit names. +var LimitNames = func() []string { names := make([]string, 0, len(nameToString)) for _, v := range nameToString { names = append(names, v) } return names }() + +// BuildBucketKey builds a bucketKey for the given rate limit name from the +// provided components. It returns an error if the name is not valid or if the +// components are not valid for the given name. +func BuildBucketKey(name Name, regId int64, singleIdent identifier.ACMEIdentifier, setOfIdents identifier.ACMEIdentifiers, subscriberIP netip.Addr) (string, error) { + makeMissingErr := func(field string) error { + return fmt.Errorf("%s is required for limit %s (enum: %s)", field, name, name.EnumString()) + } + + switch name { + case NewRegistrationsPerIPAddress: + if !subscriberIP.IsValid() { + return "", makeMissingErr("subscriberIP") + } + return newIPAddressBucketKey(name, subscriberIP), nil + + case NewRegistrationsPerIPv6Range: + if !subscriberIP.IsValid() { + return "", makeMissingErr("subscriberIP") + } + prefix, err := coveringIPPrefix(name, subscriberIP) + if err != nil { + return "", err + } + return newIPv6RangeCIDRBucketKey(name, prefix), nil + + case NewOrdersPerAccount: + if regId == 0 { + return "", makeMissingErr("regId") + } + return newRegIdBucketKey(name, regId), nil + + case CertificatesPerDomain: + if singleIdent.Value == "" { + return "", makeMissingErr("singleIdent") + } + coveringIdent, err := coveringIdentifier(name, singleIdent) + if err != nil { + return "", err + } + return newDomainOrCIDRBucketKey(name, coveringIdent), nil + + case CertificatesPerDomainPerAccount: + if singleIdent.Value != "" { + if regId == 0 { + return "", makeMissingErr("regId") + } + // Default: use 'enum:regId:identValue' bucket key format. + coveringIdent, err := coveringIdentifier(name, singleIdent) + if err != nil { + return "", err + } + return NewRegIdIdentValueBucketKey(name, regId, coveringIdent), nil + } + if regId == 0 { + return "", makeMissingErr("regId") + } + // Override: use 'enum:regId' bucket key format. + return newRegIdBucketKey(name, regId), nil + + case CertificatesPerFQDNSet: + if len(setOfIdents) == 0 { + return "", makeMissingErr("setOfIdents") + } + return newFQDNSetBucketKey(name, setOfIdents), nil + + case FailedAuthorizationsPerDomainPerAccount, FailedAuthorizationsForPausingPerDomainPerAccount: + if singleIdent.Value != "" { + if regId == 0 { + return "", makeMissingErr("regId") + } + // Default: use 'enum:regId:identValue' bucket key format. + return NewRegIdIdentValueBucketKey(name, regId, singleIdent.Value), nil + } + if regId == 0 { + return "", makeMissingErr("regId") + } + // Override: use 'enum:regId' bucket key format. + return newRegIdBucketKey(name, regId), nil + } + + return "", fmt.Errorf("unknown limit enum %s", name.EnumString()) +} diff --git a/ratelimits/names_test.go b/ratelimits/names_test.go index 93e710643..1c65c936a 100644 --- a/ratelimits/names_test.go +++ b/ratelimits/names_test.go @@ -2,8 +2,11 @@ package ratelimits import ( "fmt" + "net/netip" + "strings" "testing" + "github.com/letsencrypt/boulder/identifier" "github.com/letsencrypt/boulder/test" ) @@ -293,3 +296,202 @@ func TestValidateIdForName(t *testing.T) { }) } } + +func TestBuildBucketKey(t *testing.T) { + t.Parallel() + + tests := []struct { + name Name + desc string + regId int64 + singleIdent identifier.ACMEIdentifier + setOfIdents identifier.ACMEIdentifiers + subscriberIP netip.Addr + expectErrContains string + outputTest func(t *testing.T, key string) + }{ + // NewRegistrationsPerIPAddress + { + name: NewRegistrationsPerIPAddress, + desc: "valid subscriber IPv4 address", + subscriberIP: netip.MustParseAddr("1.2.3.4"), + outputTest: func(t *testing.T, key string) { + test.AssertEquals(t, fmt.Sprintf("%d:1.2.3.4", NewRegistrationsPerIPAddress), key) + }, + }, + { + name: NewRegistrationsPerIPAddress, + desc: "valid subscriber IPv6 address", + subscriberIP: netip.MustParseAddr("2001:db8::1"), + outputTest: func(t *testing.T, key string) { + test.AssertEquals(t, fmt.Sprintf("%d:2001:db8::1", NewRegistrationsPerIPAddress), key) + }, + }, + // NewRegistrationsPerIPv6Range + { + name: NewRegistrationsPerIPv6Range, + desc: "valid subscriber IPv6 address", + subscriberIP: netip.MustParseAddr("2001:db8:abcd:12::1"), + outputTest: func(t *testing.T, key string) { + test.AssertEquals(t, fmt.Sprintf("%d:2001:db8:abcd::/48", NewRegistrationsPerIPv6Range), key) + }, + }, + { + name: NewRegistrationsPerIPv6Range, + desc: "subscriber IPv4 given for subscriber IPv6 range limit", + subscriberIP: netip.MustParseAddr("1.2.3.4"), + expectErrContains: "requires an IPv6 address", + }, + + // NewOrdersPerAccount + { + name: NewOrdersPerAccount, + desc: "valid registration ID", + regId: 1337, + outputTest: func(t *testing.T, key string) { + test.AssertEquals(t, fmt.Sprintf("%d:1337", NewOrdersPerAccount), key) + }, + }, + { + name: NewOrdersPerAccount, + desc: "registration ID missing", + expectErrContains: "regId is required", + }, + + // CertificatesPerDomain + { + name: CertificatesPerDomain, + desc: "DNS identifier to eTLD+1", + singleIdent: identifier.NewDNS("www.example.com"), + outputTest: func(t *testing.T, key string) { + test.AssertEquals(t, fmt.Sprintf("%d:example.com", CertificatesPerDomain), key) + }, + }, + { + name: CertificatesPerDomain, + desc: "valid IPv4 address used as identifier", + singleIdent: identifier.NewIP(netip.MustParseAddr("5.6.7.8")), + outputTest: func(t *testing.T, key string) { + test.AssertEquals(t, fmt.Sprintf("%d:5.6.7.8/32", CertificatesPerDomain), key) + }, + }, + { + name: CertificatesPerDomain, + desc: "valid IPv6 address used as identifier", + singleIdent: identifier.NewIP(netip.MustParseAddr("2001:db8::1")), + outputTest: func(t *testing.T, key string) { + test.AssertEquals(t, fmt.Sprintf("%d:2001:db8::/64", CertificatesPerDomain), key) + }, + }, + { + name: CertificatesPerDomain, + desc: "identifier missing", + expectErrContains: "singleIdent is required", + }, + + // CertificatesPerFQDNSet + { + name: CertificatesPerFQDNSet, + desc: "multiple valid DNS identifiers", + setOfIdents: identifier.NewDNSSlice([]string{"example.com", "example.org"}), + outputTest: func(t *testing.T, key string) { + if !strings.HasPrefix(key, fmt.Sprintf("%d:", CertificatesPerFQDNSet)) { + t.Errorf("expected key to start with %d: got %s", CertificatesPerFQDNSet, key) + } + }, + }, + { + name: CertificatesPerFQDNSet, + desc: "multiple valid DNS and IP identifiers", + setOfIdents: identifier.ACMEIdentifiers{identifier.NewDNS("example.net"), identifier.NewIP(netip.MustParseAddr("5.6.7.8")), identifier.NewIP(netip.MustParseAddr("2001:db8::1"))}, + outputTest: func(t *testing.T, key string) { + if !strings.HasPrefix(key, fmt.Sprintf("%d:", CertificatesPerFQDNSet)) { + t.Errorf("expected key to start with %d: got %s", CertificatesPerFQDNSet, key) + } + }, + }, + { + name: CertificatesPerFQDNSet, + desc: "identifiers missing", + expectErrContains: "setOfIdents is required", + }, + + // CertificatesPerDomainPerAccount + { + name: CertificatesPerDomainPerAccount, + desc: "only registration ID", + regId: 1337, + outputTest: func(t *testing.T, key string) { + test.AssertEquals(t, fmt.Sprintf("%d:1337", CertificatesPerDomainPerAccount), key) + }, + }, + { + name: CertificatesPerDomainPerAccount, + desc: "registration ID and single DNS identifier provided", + regId: 1337, + singleIdent: identifier.NewDNS("example.com"), + outputTest: func(t *testing.T, key string) { + test.AssertEquals(t, fmt.Sprintf("%d:1337:example.com", CertificatesPerDomainPerAccount), key) + }, + }, + { + name: CertificatesPerDomainPerAccount, + desc: "single DNS identifier provided without registration ID", + singleIdent: identifier.NewDNS("example.com"), + expectErrContains: "regId is required", + }, + + // FailedAuthorizationsPerDomainPerAccount + { + name: FailedAuthorizationsPerDomainPerAccount, + desc: "registration ID and single DNS identifier", + regId: 1337, + singleIdent: identifier.NewDNS("example.com"), + outputTest: func(t *testing.T, key string) { + test.AssertEquals(t, fmt.Sprintf("%d:1337:example.com", FailedAuthorizationsPerDomainPerAccount), key) + }, + }, + { + name: FailedAuthorizationsPerDomainPerAccount, + desc: "only registration ID", + regId: 1337, + outputTest: func(t *testing.T, key string) { + test.AssertEquals(t, fmt.Sprintf("%d:1337", FailedAuthorizationsPerDomainPerAccount), key) + }, + }, + + // FailedAuthorizationsForPausingPerDomainPerAccount + { + name: FailedAuthorizationsForPausingPerDomainPerAccount, + desc: "registration ID and single DNS identifier", + regId: 1337, + singleIdent: identifier.NewDNS("example.com"), + outputTest: func(t *testing.T, key string) { + test.AssertEquals(t, fmt.Sprintf("%d:1337:example.com", FailedAuthorizationsForPausingPerDomainPerAccount), key) + }, + }, + { + name: FailedAuthorizationsForPausingPerDomainPerAccount, + desc: "only registration ID", + regId: 1337, + outputTest: func(t *testing.T, key string) { + test.AssertEquals(t, fmt.Sprintf("%d:1337", FailedAuthorizationsForPausingPerDomainPerAccount), key) + }, + }, + } + + for _, tc := range tests { + t.Run(fmt.Sprintf("%s/%s", tc.name, tc.desc), func(t *testing.T) { + t.Parallel() + + key, err := BuildBucketKey(tc.name, tc.regId, tc.singleIdent, tc.setOfIdents, tc.subscriberIP) + if tc.expectErrContains != "" { + test.AssertError(t, err, "expected error") + test.AssertContains(t, err.Error(), tc.expectErrContains) + return + } + test.AssertNotError(t, err, "unexpected error") + tc.outputTest(t, key) + }) + } +} diff --git a/ratelimits/testdata/busted_override_burst_0.yml b/ratelimits/testdata/busted_override_burst_0.yml index 9c74e16ac..9110fc1aa 100644 --- a/ratelimits/testdata/busted_override_burst_0.yml +++ b/ratelimits/testdata/busted_override_burst_0.yml @@ -3,5 +3,5 @@ count: 40 period: 1s ids: - - id: 10.0.0.2 + - id: 55.66.77.88 comment: Foo diff --git a/ratelimits/transaction.go b/ratelimits/transaction.go index adbed90c7..0e3d25b9c 100644 --- a/ratelimits/transaction.go +++ b/ratelimits/transaction.go @@ -16,32 +16,25 @@ var ErrInvalidCost = fmt.Errorf("invalid cost, must be >= 0") // ErrInvalidCostOverLimit indicates that the cost specified was > limit.Burst. var ErrInvalidCostOverLimit = fmt.Errorf("invalid cost, must be <= limit.Burst") -// newIPAddressBucketKey validates and returns a bucketKey for limits that use +// newIPAddressBucketKey returns a bucketKey for limits that use // the 'enum:ipAddress' bucket key format. -func newIPAddressBucketKey(name Name, ip netip.Addr) string { //nolint:unparam // Only one named rate limit uses this helper +func newIPAddressBucketKey(name Name, ip netip.Addr) string { return joinWithColon(name.EnumString(), ip.String()) } -// newIPv6RangeCIDRBucketKey validates and returns a bucketKey for limits that +// newIPv6RangeCIDRBucketKey returns a bucketKey for limits that // use the 'enum:ipv6RangeCIDR' bucket key format. -func newIPv6RangeCIDRBucketKey(name Name, ip netip.Addr) (string, error) { - if ip.Is4() { - return "", fmt.Errorf("invalid IPv6 address, %q must be an IPv6 address", ip.String()) - } - prefix, err := ip.Prefix(48) - if err != nil { - return "", fmt.Errorf("invalid IPv6 address, can't calculate prefix of %q: %s", ip.String(), err) - } - return joinWithColon(name.EnumString(), prefix.String()), nil +func newIPv6RangeCIDRBucketKey(name Name, prefix netip.Prefix) string { + return joinWithColon(name.EnumString(), prefix.String()) } -// newRegIdBucketKey validates and returns a bucketKey for limits that use the +// newRegIdBucketKey returns a bucketKey for limits that use the // 'enum:regId' bucket key format. func newRegIdBucketKey(name Name, regId int64) string { return joinWithColon(name.EnumString(), strconv.FormatInt(regId, 10)) } -// newDomainOrCIDRBucketKey validates and returns a bucketKey for limits that use +// newDomainOrCIDRBucketKey returns a bucketKey for limits that use // the 'enum:domainOrCIDR' bucket key formats. func newDomainOrCIDRBucketKey(name Name, domainOrCIDR string) string { return joinWithColon(name.EnumString(), domainOrCIDR) @@ -56,7 +49,7 @@ func NewRegIdIdentValueBucketKey(name Name, regId int64, orderIdent string) stri // newFQDNSetBucketKey validates and returns a bucketKey for limits that use the // 'enum:fqdnSet' bucket key format. -func newFQDNSetBucketKey(name Name, orderIdents identifier.ACMEIdentifiers) string { //nolint: unparam // Only one named rate limit uses this helper +func newFQDNSetBucketKey(name Name, orderIdents identifier.ACMEIdentifiers) string { return joinWithColon(name.EnumString(), fmt.Sprintf("%x", core.HashIdentifiers(orderIdents))) } @@ -80,7 +73,7 @@ func newFQDNSetBucketKey(name Name, orderIdents identifier.ACMEIdentifiers) stri // it would fail validateTransaction (for instance because cost and burst are zero). type Transaction struct { bucketKey string - limit *limit + limit *Limit cost int64 check bool spend bool @@ -102,7 +95,7 @@ func validateTransaction(txn Transaction) (Transaction, error) { if txn.cost < 0 { return Transaction{}, ErrInvalidCost } - if txn.limit.burst == 0 { + if txn.limit.Burst == 0 { // This should never happen. If the limit was loaded from a file, // Burst was validated then. If this is a zero-valued Transaction // (that is, an allow-only transaction), then validateTransaction @@ -110,13 +103,13 @@ func validateTransaction(txn Transaction) (Transaction, error) { // valid. return Transaction{}, fmt.Errorf("invalid limit, burst must be > 0") } - if txn.cost > txn.limit.burst { + if txn.cost > txn.limit.Burst { return Transaction{}, ErrInvalidCostOverLimit } return txn, nil } -func newTransaction(limit *limit, bucketKey string, cost int64) (Transaction, error) { +func newTransaction(limit *Limit, bucketKey string, cost int64) (Transaction, error) { return validateTransaction(Transaction{ bucketKey: bucketKey, limit: limit, @@ -126,7 +119,7 @@ func newTransaction(limit *limit, bucketKey string, cost int64) (Transaction, er }) } -func newCheckOnlyTransaction(limit *limit, bucketKey string, cost int64) (Transaction, error) { +func newCheckOnlyTransaction(limit *Limit, bucketKey string, cost int64) (Transaction, error) { return validateTransaction(Transaction{ bucketKey: bucketKey, limit: limit, @@ -135,7 +128,7 @@ func newCheckOnlyTransaction(limit *limit, bucketKey string, cost int64) (Transa }) } -func newSpendOnlyTransaction(limit *limit, bucketKey string, cost int64) (Transaction, error) { +func newSpendOnlyTransaction(limit *Limit, bucketKey string, cost int64) (Transaction, error) { return validateTransaction(Transaction{ bucketKey: bucketKey, limit: limit, @@ -197,10 +190,12 @@ func (builder *TransactionBuilder) registrationsPerIPAddressTransaction(ip netip // NewRegistrationsPerIPv6Range limit for the /48 IPv6 range which contains the // provided IPv6 address. func (builder *TransactionBuilder) registrationsPerIPv6RangeTransaction(ip netip.Addr) (Transaction, error) { - bucketKey, err := newIPv6RangeCIDRBucketKey(NewRegistrationsPerIPv6Range, ip) + prefix, err := coveringIPPrefix(NewRegistrationsPerIPv6Range, ip) if err != nil { - return Transaction{}, err + return Transaction{}, fmt.Errorf("computing covering prefix for %q: %w", ip, err) } + bucketKey := newIPv6RangeCIDRBucketKey(NewRegistrationsPerIPv6Range, prefix) + limit, err := builder.getLimit(NewRegistrationsPerIPv6Range, bucketKey) if err != nil { if errors.Is(err, errLimitDisabled) { diff --git a/ratelimits/transaction_test.go b/ratelimits/transaction_test.go index e1e37bf8f..a0fce990d 100644 --- a/ratelimits/transaction_test.go +++ b/ratelimits/transaction_test.go @@ -223,7 +223,7 @@ func TestNewTransactionBuilder(t *testing.T) { newRegDefault, ok := tb.limitRegistry.defaults[NewRegistrationsPerIPAddress.EnumString()] test.Assert(t, ok, "NewRegistrationsPerIPAddress was not populated in registry") - test.AssertEquals(t, newRegDefault.burst, expectedBurst) - test.AssertEquals(t, newRegDefault.count, expectedCount) - test.AssertEquals(t, newRegDefault.period, expectedPeriod) + test.AssertEquals(t, newRegDefault.Burst, expectedBurst) + test.AssertEquals(t, newRegDefault.Count, expectedCount) + test.AssertEquals(t, newRegDefault.Period, expectedPeriod) } diff --git a/ratelimits/utilities.go b/ratelimits/utilities.go index 7999b80d0..17921c5ad 100644 --- a/ratelimits/utilities.go +++ b/ratelimits/utilities.go @@ -16,57 +16,106 @@ func joinWithColon(args ...string) string { return strings.Join(args, ":") } -// coveringIdentifiers transforms a slice of ACMEIdentifiers into strings of -// their "covering" identifiers, for the CertificatesPerDomain limit. It also -// de-duplicates the output. For DNS identifiers, this is eTLD+1's; exact public -// suffix matches are included. For IP address identifiers, this is the address -// (/32) for IPv4, or the /64 prefix for IPv6, in CIDR notation. +// coveringIdentifiers returns the set of "covering" identifiers used to enforce +// the CertificatesPerDomain rate limit. For DNS names, this is the eTLD+1 as +// determined by the Public Suffix List; exact public suffix matches are +// preserved. For IP addresses, the covering prefix is /32 for IPv4 and /64 for +// IPv6. This groups requests by registered domain or address block to match the +// scope of the limit. The result is deduplicated and lowercased. If the +// identifier type is unsupported, an error is returned. func coveringIdentifiers(idents identifier.ACMEIdentifiers) ([]string, error) { var covers []string for _, ident := range idents { - switch ident.Type { - case identifier.TypeDNS: - domain, err := publicsuffix.Domain(ident.Value) - if err != nil { - if err.Error() == fmt.Sprintf("%s is a suffix", ident.Value) { - // If the public suffix is the domain itself, that's fine. - // Include the original name in the result. - covers = append(covers, ident.Value) - continue - } else { - return nil, err - } - } - covers = append(covers, domain) - case identifier.TypeIP: - ip, err := netip.ParseAddr(ident.Value) - if err != nil { - return nil, err - } - prefix, err := coveringPrefix(ip) - if err != nil { - return nil, err - } - covers = append(covers, prefix.String()) + cover, err := coveringIdentifier(CertificatesPerDomain, ident) + if err != nil { + return nil, err } + covers = append(covers, cover) } return core.UniqueLowerNames(covers), nil } -// coveringPrefix transforms a netip.Addr into its "covering" prefix, for the -// CertificatesPerDomain limit. For IPv4, this is the IP address (/32). For -// IPv6, this is the /64 that contains the address. -func coveringPrefix(addr netip.Addr) (netip.Prefix, error) { - var bits int - if addr.Is4() { - bits = 32 - } else { - bits = 64 +// coveringIdentifier returns the "covering" identifier used to enforce the +// CertificatesPerDomain, CertificatesPerDomainPerAccount, and +// NewRegistrationsPerIPv6Range rate limits. For DNS names, this is the eTLD+1 +// as determined by the Public Suffix List; exact public suffix matches are +// preserved. For IP addresses, the covering prefix depends on the limit: +// +// - CertificatesPerDomain and CertificatesPerDomainPerAccount: +// - /32 for IPv4 +// - /64 for IPv6 +// +// - NewRegistrationsPerIPv6Range: +// - /48 for IPv6 only +// +// This groups requests by registered domain or address block to match the scope +// of each limit. The result is deduplicated and lowercased. If the identifier +// type or limit is unsupported, an error is returned. +func coveringIdentifier(limit Name, ident identifier.ACMEIdentifier) (string, error) { + switch ident.Type { + case identifier.TypeDNS: + domain, err := publicsuffix.Domain(ident.Value) + if err != nil { + if err.Error() == fmt.Sprintf("%s is a suffix", ident.Value) { + // If the public suffix is the domain itself, that's fine. + // Include the original name in the result. + return ident.Value, nil + } + return "", err + } + return domain, nil + case identifier.TypeIP: + ip, err := netip.ParseAddr(ident.Value) + if err != nil { + return "", err + } + prefix, err := coveringIPPrefix(limit, ip) + if err != nil { + return "", err + } + return prefix.String(), nil } - prefix, err := addr.Prefix(bits) - if err != nil { - // This should be impossible because bits is hardcoded. - return netip.Prefix{}, err - } - return prefix, nil + return "", fmt.Errorf("unsupported identifier type: %s", ident.Type) +} + +// coveringIPPrefix returns the "covering" IP prefix used to enforce the +// CertificatesPerDomain, CertificatesPerDomainPerAccount, and +// NewRegistrationsPerIPv6Range rate limits. The prefix length depends on the +// limit and IP version: +// +// - CertificatesPerDomain and CertificatesPerDomainPerAccount: +// - /32 for IPv4 +// - /64 for IPv6 +// +// - NewRegistrationsPerIPv6Range: +// - /48 for IPv6 only +// +// This groups requests by address block to match the scope of each limit. If +// the limit does not require a covering prefix, an error is returned. +func coveringIPPrefix(limit Name, addr netip.Addr) (netip.Prefix, error) { + switch limit { + case CertificatesPerDomain, CertificatesPerDomainPerAccount: + var bits int + if addr.Is4() { + bits = 32 + } else { + bits = 64 + } + prefix, err := addr.Prefix(bits) + if err != nil { + return netip.Prefix{}, fmt.Errorf("building covering prefix for %s: %w", addr, err) + } + return prefix, nil + + case NewRegistrationsPerIPv6Range: + if !addr.Is6() { + return netip.Prefix{}, fmt.Errorf("limit %s requires an IPv6 address, got %s", limit, addr) + } + prefix, err := addr.Prefix(48) + if err != nil { + return netip.Prefix{}, fmt.Errorf("building covering prefix for %s: %w", addr, err) + } + return prefix, nil + } + return netip.Prefix{}, fmt.Errorf("limit %s does not require a covering prefix", limit) }