Initial implementation of key-value rate limits (#6947)

This design seeks to reduce read-pressure on our DB by moving rate limit
tabulation to a key-value datastore. This PR provides the following:

- (README.md) a short guide to the schemas, formats, and concepts
introduced in this PR
- (source.go) an interface for storing, retrieving, and resetting a
subscriber bucket
- (name.go) an enumeration of all defined rate limits
- (limit.go) a schema for defining default limits and per-subscriber
overrides
- (limiter.go) a high-level API for interacting with key-value rate
limits
- (gcra.go) an implementation of the Generic Cell Rate Algorithm, a
leaky bucket-style scheduling algorithm, used to calculate the present
or future capacity of a subscriber bucket using spend and refund
operations

Note: the included source implementation is test-only and is currently
implemented using a simple in-memory map protected by a mutex;
implementations using Redis and potentially other data stores will
follow.

Part of #5545
This commit is contained in:
Samantha 2023-07-21 12:57:18 -04:00 committed by GitHub
parent e955494955
commit 055f620c4b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
33 changed files with 1969 additions and 51 deletions

View File

@ -304,7 +304,7 @@ func (m *mailer) updateLastNagTimestampsChunk(ctx context.Context, certs []*x509
}
func (m *mailer) certIsRenewed(ctx context.Context, names []string, issued time.Time) (bool, error) {
namehash := sa.HashNames(names)
namehash := core.HashNames(names)
var present bool
err := m.dbMap.SelectOne(

View File

@ -353,7 +353,7 @@ func TestNoContactCertIsRenewed(t *testing.T) {
setupDBMap, err := sa.DBMapForTest(vars.DBConnSAFullPerms)
test.AssertNotError(t, err, "setting up DB")
err = setupDBMap.Insert(ctx, &core.FQDNSet{
SetHash: sa.HashNames(names),
SetHash: core.HashNames(names),
Serial: core.SerialToString(serial2),
Issued: testCtx.fc.Now().Add(time.Hour),
Expires: expires.Add(time.Hour),
@ -580,13 +580,13 @@ func addExpiringCerts(t *testing.T, ctx *testCtx) []certDERWithRegID {
test.AssertNotError(t, err, "creating cert D")
fqdnStatusD := &core.FQDNSet{
SetHash: sa.HashNames(certDNames),
SetHash: core.HashNames(certDNames),
Serial: serial4String,
Issued: ctx.fc.Now().AddDate(0, 0, -87),
Expires: ctx.fc.Now().AddDate(0, 0, 3),
}
fqdnStatusDRenewed := &core.FQDNSet{
SetHash: sa.HashNames(certDNames),
SetHash: core.HashNames(certDNames),
Serial: serial5String,
Issued: ctx.fc.Now().AddDate(0, 0, -3),
Expires: ctx.fc.Now().AddDate(0, 0, 87),
@ -747,7 +747,7 @@ func TestCertIsRenewed(t *testing.T) {
t.Fatal(err)
}
fqdnStatus := &core.FQDNSet{
SetHash: sa.HashNames(testData.DNS),
SetHash: core.HashNames(testData.DNS),
Serial: testData.stringSerial,
Issued: testData.NotBefore,
Expires: testData.NotAfter,

View File

@ -242,6 +242,14 @@ func UniqueLowerNames(names []string) (unique []string) {
return
}
// HashNames returns a SHA-256 digest of the requested names after they have
// been deduplicated and lowercased, joined by commas. This is intended for
// use when interacting with the orderFqdnSets table and rate limiting.
func HashNames(names []string) []byte {
	joined := strings.Join(UniqueLowerNames(names), ",")
	digest := sha256.Sum256([]byte(joined))
	return digest[:]
}
// LoadCert loads a PEM certificate specified by filename or returns an error
func LoadCert(filename string) (*x509.Certificate, error) {
certPEM, err := os.ReadFile(filename)

View File

@ -1,6 +1,7 @@
package core
import (
"bytes"
"encoding/json"
"fmt"
"math"
@ -206,3 +207,30 @@ func TestRetryBackoff(t *testing.T) {
assertBetween(float64(backoff), float64(expected)*0.8, float64(expected)*1.2)
}
// TestHashNames verifies that the hash is deterministic and insensitive to
// ordering, case, and duplication, while still differentiating distinct sets.
func TestHashNames(t *testing.T) {
	testCases := []struct {
		name      string
		a, b      []string
		wantEqual bool
	}{
		{"deterministic", []string{"a"}, []string{"a"}, true},
		{"differentiates", []string{"a"}, []string{"b"}, false},
		{"not subject to ordering", []string{"a", "b"}, []string{"b", "a"}, true},
		{"not subject to case", []string{"a", "b"}, []string{"A", "B"}, true},
		{"not subject to duplication", []string{"a", "a"}, []string{"a"}, true},
	}
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			h1 := HashNames(tc.a)
			h2 := HashNames(tc.b)
			if tc.wantEqual {
				test.AssertByteEquals(t, h1, h2)
			} else {
				test.Assert(t, !bytes.Equal(h1, h2), "Should have been different")
			}
		})
	}
}

View File

@ -196,7 +196,7 @@ var (
errWildcardNotSupported = berrors.MalformedError("Wildcard domain names are not supported")
)
// validDomain checks that a domain isn't:
// ValidDomain checks that a domain isn't:
//
// * empty
// * prefixed with the wildcard label `*.`
@ -210,7 +210,7 @@ var (
// * exactly equal to an IANA registered TLD
//
// It does _not_ check that the domain isn't on any PA blocked lists.
func validDomain(domain string) error {
func ValidDomain(domain string) error {
if domain == "" {
return errEmptyName
}
@ -323,7 +323,7 @@ func ValidEmail(address string) error {
}
splitEmail := strings.SplitN(email.Address, "@", -1)
domain := strings.ToLower(splitEmail[len(splitEmail)-1])
err = validDomain(domain)
err = ValidDomain(domain)
if err != nil {
return berrors.InvalidEmailError(
"contact email %q has invalid domain : %s",
@ -363,7 +363,7 @@ func (pa *AuthorityImpl) willingToIssue(id identifier.ACMEIdentifier) error {
}
domain := id.Value
err := validDomain(domain)
err := ValidDomain(domain)
if err != nil {
return err
}

190
ratelimits/README.md Normal file
View File

@ -0,0 +1,190 @@
# Configuring and Storing Key-Value Rate Limits
## Rate Limit Structure
All rate limits use a token-bucket model. The metaphor is that each limit is
represented by a bucket which holds tokens. Each request removes some number of
tokens from the bucket, or is denied if there aren't enough tokens to remove.
Over time, new tokens are added to the bucket at a steady rate, until the bucket
is full. The _burst_ parameter of a rate limit indicates the maximum capacity of
a bucket: how many tokens can it hold before new ones stop being added.
Therefore, this also indicates how many requests can be made in a single burst
before a full bucket is completely emptied. The _count_ and _period_ parameters
indicate the rate at which new tokens are added to a bucket: every period, count
tokens will be added. Therefore, these also indicate the steady-state rate at
which a client which has exhausted its quota can make requests: one token every
(period / count) duration.
## Default Limit Settings
Each key directly corresponds to a `Name` enumeration as detailed in `//ratelimits/names.go`.
The `Name` enum is used to identify the particular limit. The parameters of a
default limit are the values that will be used for all buckets that do not have
an explicit override (see below).
```yaml
NewRegistrationsPerIPAddress:
burst: 20
count: 20
period: 1s
NewOrdersPerAccount:
burst: 300
count: 300
period: 180m
```
## Override Limit Settings
Each override key represents a specific bucket, consisting of two elements:
_name_ and _id_. The name here refers to the Name of the particular limit, while
the id is a client identifier. The format of the id is dependent on the limit.
For example, the id for 'NewRegistrationsPerIPAddress' is a subscriber IP
address, while the id for 'NewOrdersPerAccount' is the subscriber's registration
ID.
```yaml
NewRegistrationsPerIPAddress:10.0.0.2:
burst: 20
count: 40
period: 1s
NewOrdersPerAccount:12345678:
burst: 300
count: 600
period: 180m
```
The above example overrides the default limits for specific subscribers. In both
cases the count of requests per period are doubled, but the burst capacity is
explicitly configured to match the default rate limit.
### Id Formats in Limit Override Settings
Id formats vary based on the `Name` enumeration. Below are examples for each
format:
#### ipAddress
A valid IPv4 or IPv6 address.
Examples:
- `NewRegistrationsPerIPAddress:10.0.0.1`
- `NewRegistrationsPerIPAddress:2001:0db8:0000:0000:0000:ff00:0042:8329`
#### ipv6RangeCIDR
A valid IPv6 range in CIDR notation with a /48 mask. A /48 range is typically
assigned to a single subscriber.
Example: `NewRegistrationsPerIPv6Range:2001:0db8:0000::/48`
#### regId
The registration ID of the account.
Example: `NewOrdersPerAccount:12345678`
#### regId:domain
A combination of registration ID and domain, formatted 'regId:domain'.
Example: `CertificatesPerDomainPerAccount:12345678:example.com`
#### regId:fqdnSet
A combination of registration ID and a comma-separated list of domain names,
formatted 'regId:fqdnSet'.
Example: `CertificatesPerFQDNSetPerAccount:12345678:example.com,example.org`
## Bucket Key Definitions
A bucket key is used to look up the bucket for a given limit and
subscriber. Bucket keys are formatted similarly to the overrides, with one
difference: the limit name appears as the integer value of its `Name` enum
rather than as its string form.
So, instead of:
```
NewOrdersPerAccount:12345678
```
The corresponding bucket key for regId 12345678 would look like this:
```
6:12345678
```
When loaded from a file, the keys for the default/override limits undergo the
same interning process as the aforementioned subscriber bucket keys. This
eliminates the need for redundant conversions when fetching each
default/override limit.
## How Limits are Applied
Although rate limit buckets are configured in terms of tokens, we do not
actually keep track of the number of tokens in each bucket. Instead, we track
the Theoretical Arrival Time (TAT) at which the bucket will be full again. If
the TAT is in the past, the bucket is full. If the TAT is in the future, some
number of tokens have been spent and the bucket is slowly refilling. If the TAT
is far enough in the future (specifically, more than `burst * (period / count)`
in the future), then the bucket is completely empty and requests will be denied.
Additional terminology:
- **burst offset** is the duration of time it takes for a bucket to go from
empty to full (`burst * (period / count)`).
- **emission interval** is the interval at which tokens are added to a bucket
(`period / count`). This is also the steady-state rate at which requests can
be made without being denied even once the burst has been exhausted.
- **cost** is the number of tokens removed from a bucket for a single request.
- **cost increment** is the duration of time the TAT is advanced to account
for the cost of the request (`cost * emission interval`).
For the purposes of this example, subscribers originating from a specific IPv4
address are allowed 20 requests to the newFoo endpoint per second, with a
maximum burst of 20 requests at any point-in-time, or:
```yaml
NewFoosPerIPAddress:172.23.45.22:
burst: 20
count: 20
period: 1s
```
A subscriber calls the newFoo endpoint for the first time with an IP address of
172.23.45.22. Here's what happens:
1. The subscriber's IP address is used to generate a bucket key in the form of
'NewFoosPerIPAddress:172.23.45.22'.
2. The request is approved and the 'NewFoosPerIPAddress:172.23.45.22' bucket is
initialized with 19 tokens, as 1 token has been removed to account for the
cost of the current request. To accomplish this, the initial TAT is set to
the current time plus the _cost increment_ (which is 1/20th of a second if we
are limiting to 20 requests per second).
3. Bucket 'NewFoosPerIPAddress:172.23.45.22':
- will reset to full in 50ms (1/20th of a second),
- will allow another newFoo request immediately,
- will allow between 1 and 19 more requests in the next 50ms,
- will reject the 20th request made in the next 50ms,
- and will allow 1 request every 50ms, indefinitely.
The subscriber makes another request 5ms later:
4. The TAT at bucket key 'NewFoosPerIPAddress:172.23.45.22' is compared against
the current time and the _burst offset_. The current time is greater than the
TAT minus the cost increment. Therefore, the request is approved.
5. The TAT at bucket key 'NewFoosPerIPAddress:172.23.45.22' is advanced by the
cost increment to account for the cost of the request.
The subscriber makes a total of 18 requests over the next 44ms:
6. The current time is less than the TAT at bucket key
'NewFoosPerIPAddress:172.23.45.22' minus the burst offset, thus the request
is rejected.
This mechanism allows for bursts of traffic but also ensures that the average
rate of requests stays within the prescribed limits over time.

110
ratelimits/gcra.go Normal file
View File

@ -0,0 +1,110 @@
package ratelimits
import (
"time"
"github.com/jmhodges/clock"
)
// maybeSpend uses the GCRA algorithm to decide whether to allow a request. It
// returns a Decision struct with the result of the decision and the updated
// TAT. The cost must be 0 or greater and <= the burst capacity of the limit.
func maybeSpend(clk clock.Clock, rl limit, tat time.Time, cost int64) *Decision {
	if cost < 0 || cost > rl.Burst {
		// This mirrors the checks performed by the Check and Spend methods of
		// Limiter. Reaching this panic means a caller introduced a bug.
		panic("invalid cost for maybeSpend")
	}
	now := clk.Now().UnixNano()
	currTAT := tat.UnixNano()
	// Never start the calculation from a TAT in the past; doing so would
	// credit the bucket with capacity accrued before now.
	if currTAT < now {
		currTAT = now
	}
	// The cost of the request, expressed as nanoseconds of TAT advancement.
	spendNanos := rl.emissionInterval * cost
	// The TAT the bucket would have if this request were allowed.
	proposedTAT := currTAT + spendNanos
	// headroom is the signed distance, in nanoseconds, between now and the
	// earliest moment at which the proposed TAT would be permissible.
	headroom := now - (proposedTAT - rl.burstOffset)
	if headroom < 0 {
		// Too little capacity to satisfy the cost; deny the request and leave
		// the stored TAT unchanged.
		remaining := (now - (currTAT - rl.burstOffset)) / rl.emissionInterval
		return &Decision{
			Allowed:   false,
			Remaining: remaining,
			RetryIn:   -time.Duration(headroom),
			ResetIn:   time.Duration(currTAT - now),
			newTAT:    time.Unix(0, currTAT).UTC(),
		}
	}
	// There is enough capacity to satisfy the cost; allow the request.
	remaining := headroom / rl.emissionInterval
	var retryIn time.Duration
	if headroom < spendNanos {
		// Fewer than cost tokens remain, so an identical request must wait.
		retryIn = time.Duration(spendNanos - headroom)
	}
	return &Decision{
		Allowed:   true,
		Remaining: remaining,
		RetryIn:   retryIn,
		ResetIn:   time.Duration(proposedTAT - now),
		newTAT:    time.Unix(0, proposedTAT).UTC(),
	}
}
// maybeRefund uses the Generic Cell Rate Algorithm (GCRA) to attempt to refund
// the cost of a request which was previously spent. The refund cost must be 0
// or greater. A cost will only be refunded up to the burst capacity of the
// limit. A partial refund is still considered successful.
func maybeRefund(clk clock.Clock, rl limit, tat time.Time, cost int64) *Decision {
	if cost <= 0 || cost > rl.Burst {
		// This mirrors the check performed by the Refund method of Limiter.
		// Reaching this panic means a caller introduced a bug.
		panic("invalid cost for maybeRefund")
	}
	now := clk.Now().UnixNano()
	currTAT := tat.UnixNano()
	// Capacity can only be refunded while the TAT is in the future.
	if now > currTAT {
		// The TAT is in the past: the bucket is already full, so there is
		// nothing to give back.
		return &Decision{
			Allowed:   false,
			Remaining: rl.Burst,
			RetryIn:   time.Duration(0),
			ResetIn:   time.Duration(0),
			newTAT:    tat,
		}
	}
	// The refund, expressed as nanoseconds of TAT rollback.
	refundNanos := rl.emissionInterval * cost
	adjustedTAT := currTAT - refundNanos
	// Never roll the TAT back past now; this caps the refund at a full
	// bucket, making over-large refunds partial rather than erroneous.
	if adjustedTAT < now {
		adjustedTAT = now
	}
	// Recompute the remaining capacity from the adjusted TAT.
	remaining := (now - (adjustedTAT - rl.burstOffset)) / rl.emissionInterval
	return &Decision{
		// The refund succeeded iff the TAT actually moved.
		Allowed:   adjustedTAT != currTAT,
		Remaining: remaining,
		RetryIn:   time.Duration(0),
		ResetIn:   time.Duration(adjustedTAT - now),
		newTAT:    time.Unix(0, adjustedTAT).UTC(),
	}
}

221
ratelimits/gcra_test.go Normal file
View File

@ -0,0 +1,221 @@
package ratelimits
import (
"testing"
"time"
"github.com/jmhodges/clock"
"github.com/letsencrypt/boulder/config"
"github.com/letsencrypt/boulder/test"
)
// Test_decide exercises maybeSpend end-to-end: spending, denial when empty,
// honoring RetryIn/ResetIn waits, and capping accrued capacity at the burst.
// Each step feeds the prior decision's newTAT back in, so order matters.
func Test_decide(t *testing.T) {
	clk := clock.NewFake()
	limit := precomputeLimit(
		limit{Burst: 10, Count: 1, Period: config.Duration{Duration: time.Second}},
	)
	// Begin by using 1 of our 10 requests.
	d := maybeSpend(clk, limit, clk.Now(), 1)
	test.Assert(t, d.Allowed, "should be allowed")
	test.AssertEquals(t, d.Remaining, int64(9))
	test.AssertEquals(t, d.RetryIn, time.Duration(0))
	test.AssertEquals(t, d.ResetIn, time.Second)
	// Immediately use another 9 of our remaining requests.
	d = maybeSpend(clk, limit, d.newTAT, 9)
	test.Assert(t, d.Allowed, "should be allowed")
	test.AssertEquals(t, d.Remaining, int64(0))
	// We should have to wait 1 second before we can use another request but we
	// used 9 so we should have to wait 9 seconds to make an identical request.
	test.AssertEquals(t, d.RetryIn, time.Second*9)
	test.AssertEquals(t, d.ResetIn, time.Second*10)
	// Our new TAT should be 10 seconds (limit.Burst) in the future.
	test.AssertEquals(t, d.newTAT, clk.Now().Add(time.Second*10))
	// Let's try using just 1 more request without waiting.
	d = maybeSpend(clk, limit, d.newTAT, 1)
	test.Assert(t, !d.Allowed, "should not be allowed")
	test.AssertEquals(t, d.Remaining, int64(0))
	test.AssertEquals(t, d.RetryIn, time.Second)
	test.AssertEquals(t, d.ResetIn, time.Second*10)
	// Let's try being exactly as patient as we're told to be.
	clk.Add(d.RetryIn)
	d = maybeSpend(clk, limit, d.newTAT, 0)
	test.AssertEquals(t, d.Remaining, int64(1))
	// We are 1 second in the future, we should have 1 new request.
	d = maybeSpend(clk, limit, d.newTAT, 1)
	test.Assert(t, d.Allowed, "should be allowed")
	test.AssertEquals(t, d.Remaining, int64(0))
	test.AssertEquals(t, d.RetryIn, time.Second)
	test.AssertEquals(t, d.ResetIn, time.Second*10)
	// Let's try waiting (10 seconds) for our whole bucket to refill.
	clk.Add(d.ResetIn)
	// We should have 10 new requests. If we use 1 we should have 9 remaining.
	d = maybeSpend(clk, limit, d.newTAT, 1)
	test.Assert(t, d.Allowed, "should be allowed")
	test.AssertEquals(t, d.Remaining, int64(9))
	test.AssertEquals(t, d.RetryIn, time.Duration(0))
	test.AssertEquals(t, d.ResetIn, time.Second)
	// Wait just shy of how long we're told to wait for refilling.
	clk.Add(d.ResetIn - time.Millisecond)
	// We should still have 9 remaining because we're still 1ms shy of the
	// refill time.
	d = maybeSpend(clk, limit, d.newTAT, 0)
	test.Assert(t, d.Allowed, "should be allowed")
	test.AssertEquals(t, d.Remaining, int64(9))
	test.AssertEquals(t, d.RetryIn, time.Duration(0))
	test.AssertEquals(t, d.ResetIn, time.Millisecond)
	// Spending 0 simply informed us that we still have 9 remaining, let's see
	// what we have after waiting 20 hours.
	clk.Add(20 * time.Hour)
	// C'mon, big money, no whammies, no whammies, STOP!
	d = maybeSpend(clk, limit, d.newTAT, 0)
	test.Assert(t, d.Allowed, "should be allowed")
	test.AssertEquals(t, d.Remaining, int64(10))
	test.AssertEquals(t, d.RetryIn, time.Duration(0))
	test.AssertEquals(t, d.ResetIn, time.Duration(0))
	// Turns out that the most we can accrue is 10 (limit.Burst). Let's empty
	// this bucket out so we can try something else.
	d = maybeSpend(clk, limit, d.newTAT, 10)
	test.Assert(t, d.Allowed, "should be allowed")
	test.AssertEquals(t, d.Remaining, int64(0))
	// We should have to wait 1 second before we can use another request but we
	// used 10 so we should have to wait 10 seconds to make an identical
	// request.
	test.AssertEquals(t, d.RetryIn, time.Second*10)
	test.AssertEquals(t, d.ResetIn, time.Second*10)
	// If you spend 0 while you have 0 you should get 0.
	d = maybeSpend(clk, limit, d.newTAT, 0)
	test.Assert(t, d.Allowed, "should be allowed")
	test.AssertEquals(t, d.Remaining, int64(0))
	test.AssertEquals(t, d.RetryIn, time.Duration(0))
	test.AssertEquals(t, d.ResetIn, time.Second*10)
	// We don't play by the rules, we spend 1 when we have 0.
	d = maybeSpend(clk, limit, d.newTAT, 1)
	test.Assert(t, !d.Allowed, "should not be allowed")
	test.AssertEquals(t, d.Remaining, int64(0))
	test.AssertEquals(t, d.RetryIn, time.Second)
	test.AssertEquals(t, d.ResetIn, time.Second*10)
	// Okay, maybe we should play by the rules if we want to get anywhere.
	clk.Add(d.RetryIn)
	// Our patience pays off, we should have 1 new request. Let's use it.
	d = maybeSpend(clk, limit, d.newTAT, 1)
	test.Assert(t, d.Allowed, "should be allowed")
	test.AssertEquals(t, d.Remaining, int64(0))
	test.AssertEquals(t, d.RetryIn, time.Second)
	test.AssertEquals(t, d.ResetIn, time.Second*10)
	// Refill from empty to 5.
	clk.Add(d.ResetIn / 2)
	// Attempt to spend 7 when we only have 5. We should be denied but the
	// decision should reflect a retry of 2 seconds, the time it would take to
	// refill from 5 to 7.
	d = maybeSpend(clk, limit, d.newTAT, 7)
	test.Assert(t, !d.Allowed, "should not be allowed")
	test.AssertEquals(t, d.Remaining, int64(5))
	test.AssertEquals(t, d.RetryIn, time.Second*2)
	test.AssertEquals(t, d.ResetIn, time.Second*5)
}
// Test_maybeRefund exercises maybeRefund: refunds restore capacity, are
// capped at the burst limit, and are refused once the TAT is in the past
// (i.e. the bucket is already full). Each step feeds the prior decision's
// newTAT back in, so order matters.
//
// Fixes: the assertion at "refund to 11 after catching up" checked
// !d.Allowed but carried the message "should be allowed"; several comment
// typos ("Spend 10 all 10", "to catching up", "Wait, a 2.5 seconds").
func Test_maybeRefund(t *testing.T) {
	clk := clock.NewFake()
	limit := precomputeLimit(
		limit{Burst: 10, Count: 1, Period: config.Duration{Duration: time.Second}},
	)
	// Begin by using 1 of our 10 requests.
	d := maybeSpend(clk, limit, clk.Now(), 1)
	test.Assert(t, d.Allowed, "should be allowed")
	test.AssertEquals(t, d.Remaining, int64(9))
	test.AssertEquals(t, d.RetryIn, time.Duration(0))
	test.AssertEquals(t, d.ResetIn, time.Second)
	// Refund back to 10.
	d = maybeRefund(clk, limit, d.newTAT, 1)
	test.AssertEquals(t, d.Remaining, int64(10))
	test.AssertEquals(t, d.RetryIn, time.Duration(0))
	test.AssertEquals(t, d.ResetIn, time.Duration(0))
	// Spend 1 more of our 10 requests.
	d = maybeSpend(clk, limit, d.newTAT, 1)
	test.Assert(t, d.Allowed, "should be allowed")
	test.AssertEquals(t, d.Remaining, int64(9))
	test.AssertEquals(t, d.RetryIn, time.Duration(0))
	test.AssertEquals(t, d.ResetIn, time.Second)
	// Wait for our bucket to refill.
	clk.Add(d.ResetIn)
	// Attempt to refund from 10 to 11.
	d = maybeRefund(clk, limit, d.newTAT, 1)
	test.Assert(t, !d.Allowed, "should not be allowed")
	test.AssertEquals(t, d.Remaining, int64(10))
	test.AssertEquals(t, d.RetryIn, time.Duration(0))
	test.AssertEquals(t, d.ResetIn, time.Duration(0))
	// Spend all 10 of our requests.
	d = maybeSpend(clk, limit, d.newTAT, 10)
	test.Assert(t, d.Allowed, "should be allowed")
	test.AssertEquals(t, d.Remaining, int64(0))
	// We should have to wait 1 second before we can use another request but we
	// used 10 so we should have to wait 10 seconds to make an identical
	// request.
	test.AssertEquals(t, d.RetryIn, time.Second*10)
	test.AssertEquals(t, d.ResetIn, time.Second*10)
	// Attempt a refund of 10.
	d = maybeRefund(clk, limit, d.newTAT, 10)
	test.AssertEquals(t, d.Remaining, int64(10))
	test.AssertEquals(t, d.RetryIn, time.Duration(0))
	test.AssertEquals(t, d.ResetIn, time.Duration(0))
	// Wait 11 seconds to catch up to the TAT.
	clk.Add(11 * time.Second)
	// Attempt to refund to 11, then ensure it's still 10.
	d = maybeRefund(clk, limit, d.newTAT, 1)
	test.Assert(t, !d.Allowed, "should not be allowed")
	test.AssertEquals(t, d.Remaining, int64(10))
	test.AssertEquals(t, d.RetryIn, time.Duration(0))
	test.AssertEquals(t, d.ResetIn, time.Duration(0))
	// Spend 5 of our 10 requests, then refund 1.
	d = maybeSpend(clk, limit, d.newTAT, 5)
	d = maybeRefund(clk, limit, d.newTAT, 1)
	test.Assert(t, d.Allowed, "should be allowed")
	test.AssertEquals(t, d.Remaining, int64(6))
	test.AssertEquals(t, d.RetryIn, time.Duration(0))
	// Wait 2.5 seconds to refill to 8.5 requests.
	clk.Add(time.Millisecond * 2500)
	// Ensure we have 8.5 requests.
	d = maybeSpend(clk, limit, d.newTAT, 0)
	test.Assert(t, d.Allowed, "should be allowed")
	test.AssertEquals(t, d.Remaining, int64(8))
	test.AssertEquals(t, d.RetryIn, time.Duration(0))
	// Check that ResetIn represents the fractional earned request.
	test.AssertEquals(t, d.ResetIn, time.Millisecond*1500)
	// Refund 2 requests, we should only have 10, not 10.5.
	d = maybeRefund(clk, limit, d.newTAT, 2)
	test.AssertEquals(t, d.Remaining, int64(10))
	test.AssertEquals(t, d.RetryIn, time.Duration(0))
	test.AssertEquals(t, d.ResetIn, time.Duration(0))
}

160
ratelimits/limit.go Normal file
View File

@ -0,0 +1,160 @@
package ratelimits
import (
"fmt"
"os"
"strings"
"github.com/letsencrypt/boulder/config"
"github.com/letsencrypt/boulder/core"
"github.com/letsencrypt/boulder/strictyaml"
)
// limit defines the parameters of a single rate limit: the user-configured
// Burst/Count/Period plus two derived GCRA values precomputed from them.
type limit struct {
	// Burst specifies the maximum number of requests that can be served in a
	// single burst, i.e. the token capacity of the bucket. It must be greater
	// than zero.
	Burst int64
	// Count is the number of requests allowed per period. It must be greater
	// than zero.
	Count int64
	// Period is the duration of time in which the count (of requests) is
	// allowed. It must be greater than zero.
	Period config.Duration
	// emissionInterval is the interval, in nanoseconds, at which tokens are
	// added to a bucket (period / count). This is also the steady-state rate at
	// which requests can be made without being denied even once the burst has
	// been exhausted. This is precomputed to avoid doing the same calculation
	// on every request.
	emissionInterval int64
	// burstOffset is the duration of time, in nanoseconds, it takes for a
	// bucket to go from empty to full (burst * (period / count)). This is
	// precomputed to avoid doing the same calculation on every request.
	burstOffset int64
}
// precomputeLimit returns a copy of l with its derived GCRA parameters
// (emissionInterval and burstOffset) filled in, so they are not recomputed
// on every request.
func precomputeLimit(l limit) limit {
	interval := l.Period.Nanoseconds() / l.Count
	l.emissionInterval = interval
	l.burstOffset = interval * l.Burst
	return l
}
// validateLimit returns an error unless all of l's user-configured fields
// (Burst, Count, Period) are greater than zero.
func validateLimit(l limit) error {
	switch {
	case l.Burst <= 0:
		return fmt.Errorf("invalid burst '%d', must be > 0", l.Burst)
	case l.Count <= 0:
		return fmt.Errorf("invalid count '%d', must be > 0", l.Count)
	case l.Period.Duration <= 0:
		return fmt.Errorf("invalid period '%s', must be > 0", l.Period)
	}
	return nil
}
// limits maps a limit key — a name for defaults or 'name:id' for overrides —
// to its limit definition.
type limits map[string]limit
// loadLimits unmarshals the YAML file at path into a map of limits. It
// performs no validation of the limit values; callers (see
// loadAndParseDefaultLimits and loadAndParseOverrideLimits) validate them.
func loadLimits(path string) (limits, error) {
	lm := make(limits)
	data, err := os.ReadFile(path)
	if err != nil {
		return nil, err
	}
	err = strictyaml.Unmarshal(data, &lm)
	if err != nil {
		return nil, err
	}
	return lm, nil
}
// parseOverrideNameId splits an override key of the form 'name:id' into its
// Name enum and id string. It is broken out for ease of testing.
func parseOverrideNameId(key string) (Name, string, error) {
	nameStr, id, found := strings.Cut(key, ":")
	if !found {
		// No colon at all; reject rather than guessing at the format.
		return Unknown, "", fmt.Errorf("invalid override %q, must be formatted 'name:id'", key)
	}
	if nameStr == "" {
		return Unknown, "", fmt.Errorf("empty name in override %q, must be formatted 'name:id'", key)
	}
	name, ok := stringToName[nameStr]
	if !ok {
		return Unknown, "", fmt.Errorf("unrecognized name %q in override limit %q, must be one of %v", nameStr, key, limitNames)
	}
	if id == "" {
		return Unknown, "", fmt.Errorf("empty id in override %q, must be formatted 'name:id'", key)
	}
	return name, id, nil
}
// loadAndParseOverrideLimits loads override limits from YAML, validates them,
// and parses them into a map of limits keyed by 'Name:id'.
func loadAndParseOverrideLimits(path string) (limits, error) {
	fromFile, err := loadLimits(path)
	if err != nil {
		return nil, err
	}
	parsed := make(limits, len(fromFile))
	for k, v := range fromFile {
		// Reject malformed limit values before doing any key parsing.
		err = validateLimit(v)
		if err != nil {
			return nil, fmt.Errorf("validating override limit %q: %w", k, err)
		}
		// Split the 'name:id' key into its Name enum and subscriber id.
		name, id, err := parseOverrideNameId(k)
		if err != nil {
			return nil, fmt.Errorf("parsing override limit %q: %w", k, err)
		}
		// Check that the id matches the format required by this Name.
		err = validateIdForName(name, id)
		if err != nil {
			return nil, fmt.Errorf(
				"validating name %s and id %q for override limit %q: %w", nameToString[name], id, k, err)
		}
		if name == CertificatesPerFQDNSetPerAccount {
			// FQDNSet hashes are not a nice thing to ask for in a config file,
			// so we allow the user to specify a comma-separated list of FQDNs
			// and compute the hash here.
			regIdDomains := strings.SplitN(id, ":", 2)
			if len(regIdDomains) != 2 {
				// Should never happen, the Id format was validated above.
				return nil, fmt.Errorf("invalid override limit %q, must be formatted 'name:id'", k)
			}
			regId := regIdDomains[0]
			domains := strings.Split(regIdDomains[1], ",")
			fqdnSet := core.HashNames(domains)
			// NOTE(review): fqdnSet is a raw []byte digest, so %s embeds its
			// raw bytes in the id — confirm this matches how bucket keys are
			// constructed at lookup time.
			id = fmt.Sprintf("%s:%s", regId, fqdnSet)
		}
		parsed[bucketKey(name, id)] = precomputeLimit(v)
	}
	return parsed, nil
}
// loadAndParseDefaultLimits loads default limits from YAML, validates them,
// and parses them into a map of limits keyed by 'Name'.
func loadAndParseDefaultLimits(path string) (limits, error) {
	fromFile, err := loadLimits(path)
	if err != nil {
		return nil, err
	}
	parsed := make(limits, len(fromFile))
	for nameStr, lim := range fromFile {
		if err := validateLimit(lim); err != nil {
			return nil, fmt.Errorf("parsing default limit %q: %w", nameStr, err)
		}
		// The key of a default limit must be a recognized limit name.
		enum, ok := stringToName[nameStr]
		if !ok {
			return nil, fmt.Errorf("unrecognized name %q in default limit, must be one of %v", nameStr, limitNames)
		}
		parsed[nameToEnumString(enum)] = precomputeLimit(lim)
	}
	return parsed, nil
}

333
ratelimits/limit_test.go Normal file
View File

@ -0,0 +1,333 @@
package ratelimits
import (
"os"
"testing"
"time"
"github.com/letsencrypt/boulder/config"
"github.com/letsencrypt/boulder/core"
"github.com/letsencrypt/boulder/test"
)
// Test_parseOverrideNameId covers well-formed 'name:id' keys plus the
// malformed cases: no colon, empty string, empty id, and unknown name.
func Test_parseOverrideNameId(t *testing.T) {
	ipAddrLimitStr := nameToString[NewRegistrationsPerIPAddress]
	ipv6RangeLimitStr := nameToString[NewRegistrationsPerIPv6Range]
	// 'enum:ipv4' — a valid IPv4 address parses cleanly.
	name, id, err := parseOverrideNameId(ipAddrLimitStr + ":10.0.0.1")
	test.AssertNotError(t, err, "should not error")
	test.AssertEquals(t, name, NewRegistrationsPerIPAddress)
	test.AssertEquals(t, id, "10.0.0.1")
	// 'enum:ipv6range' — a valid IPv6 range parses cleanly, even though the
	// id itself contains colons.
	name, id, err = parseOverrideNameId(ipv6RangeLimitStr + ":2001:0db8:0000::/48")
	test.AssertNotError(t, err, "should not error")
	test.AssertEquals(t, name, NewRegistrationsPerIPv6Range)
	test.AssertEquals(t, id, "2001:0db8:0000::/48")
	// A key without any colon (should never happen but must not panic).
	_, _, err = parseOverrideNameId(ipAddrLimitStr + "10.0.0.1")
	test.AssertError(t, err, "missing colon")
	// An empty key.
	_, _, err = parseOverrideNameId("")
	test.AssertError(t, err, "empty string")
	// A trailing colon yields an empty id.
	_, _, err = parseOverrideNameId(ipAddrLimitStr + ":")
	test.AssertError(t, err, "only a colon")
	// An unrecognized limit name.
	_, _, err = parseOverrideNameId("lol:noexist")
	test.AssertError(t, err, "invalid enum")
}
func Test_validateLimit(t *testing.T) {
err := validateLimit(limit{Burst: 1, Count: 1, Period: config.Duration{Duration: time.Second}})
test.AssertNotError(t, err, "valid limit")
// All of the following are invalid.
for _, l := range []limit{
{Burst: 0, Count: 1, Period: config.Duration{Duration: time.Second}},
{Burst: 1, Count: 0, Period: config.Duration{Duration: time.Second}},
{Burst: 1, Count: 1, Period: config.Duration{Duration: 0}},
} {
err = validateLimit(l)
test.AssertError(t, err, "limit should be invalid")
}
}
// Test_validateIdForName ensures validateIdForName accepts well-formed ids
// and rejects malformed or empty ids for each limit Name's documented bucket
// key format ('enum:ipAddress', 'enum:ipv6rangeCIDR', 'enum:regId',
// 'enum:regId:domain', and 'enum:regId:fqdnSet').
func Test_validateIdForName(t *testing.T) {
	// 'enum:ipAddress'
	// Valid IPv4 address.
	err := validateIdForName(NewRegistrationsPerIPAddress, "10.0.0.1")
	test.AssertNotError(t, err, "valid ipv4 address")
	// 'enum:ipAddress'
	// Valid IPv6 address.
	err = validateIdForName(NewRegistrationsPerIPAddress, "2001:0db8:85a3:0000:0000:8a2e:0370:7334")
	test.AssertNotError(t, err, "valid ipv6 address")
	// 'enum:ipv6rangeCIDR'
	// Valid IPv6 address range.
	err = validateIdForName(NewRegistrationsPerIPv6Range, "2001:0db8:0000::/48")
	test.AssertNotError(t, err, "should not error")
	// 'enum:regId'
	// Valid regId.
	err = validateIdForName(NewOrdersPerAccount, "1234567890")
	test.AssertNotError(t, err, "valid regId")
	// 'enum:regId:domain'
	// Valid regId and domain.
	err = validateIdForName(CertificatesPerDomainPerAccount, "1234567890:example.com")
	test.AssertNotError(t, err, "valid regId and domain")
	// 'enum:regId:fqdnSet'
	// Valid regId and FQDN set containing a single domain.
	err = validateIdForName(CertificatesPerFQDNSetPerAccount, "1234567890:example.com")
	test.AssertNotError(t, err, "valid regId and FQDN set containing a single domain")
	// 'enum:regId:fqdnSet'
	// Valid regId and FQDN set containing multiple domains.
	err = validateIdForName(CertificatesPerFQDNSetPerAccount, "1234567890:example.com,example.org")
	test.AssertNotError(t, err, "valid regId and FQDN set containing multiple domains")
	// Empty string.
	err = validateIdForName(NewRegistrationsPerIPAddress, "")
	test.AssertError(t, err, "Id is an empty string")
	// One space.
	err = validateIdForName(NewRegistrationsPerIPAddress, " ")
	test.AssertError(t, err, "Id is a single space")
	// Invalid IPv4 address.
	err = validateIdForName(NewRegistrationsPerIPAddress, "10.0.0.9000")
	test.AssertError(t, err, "invalid IPv4 address")
	// Invalid IPv6 address.
	err = validateIdForName(NewRegistrationsPerIPAddress, "2001:0db8:85a3:0000:0000:8a2e:0370:7334:9000")
	test.AssertError(t, err, "invalid IPv6 address")
	// IPv6 CIDR range with a mask other than the required /48.
	err = validateIdForName(NewRegistrationsPerIPv6Range, "2001:0db8:0000::/128")
	test.AssertError(t, err, "invalid IPv6 CIDR range")
	// Malformed IPv6 CIDR (two masks).
	err = validateIdForName(NewRegistrationsPerIPv6Range, "2001:0db8:0000::/48/48")
	test.AssertError(t, err, "invalid IPv6 CIDR")
	// IPv4 CIDR when we expect IPv6 CIDR range.
	err = validateIdForName(NewRegistrationsPerIPv6Range, "10.0.0.0/16")
	test.AssertError(t, err, "ipv4 cidr when we expect ipv6 cidr range")
	// Invalid regId.
	err = validateIdForName(NewOrdersPerAccount, "lol")
	test.AssertError(t, err, "invalid regId")
	// Invalid regId with good domain.
	err = validateIdForName(CertificatesPerDomainPerAccount, "lol:example.com")
	test.AssertError(t, err, "invalid regId with good domain")
	// Valid regId with bad domain.
	err = validateIdForName(CertificatesPerDomainPerAccount, "1234567890:lol")
	test.AssertError(t, err, "valid regId with bad domain")
	// Empty regId with bad domain.
	err = validateIdForName(CertificatesPerDomainPerAccount, ":lol")
	// NOTE(review): the message below looks copy-pasted from the previous
	// case; this case exercises an empty regId.
	test.AssertError(t, err, "valid regId with bad domain")
	// Valid regId with empty domain.
	err = validateIdForName(CertificatesPerDomainPerAccount, "1234567890:")
	test.AssertError(t, err, "valid regId with empty domain")
	// Empty regId with empty domain, no separator.
	err = validateIdForName(CertificatesPerDomainPerAccount, "")
	test.AssertError(t, err, "empty regId with empty domain, no separator")
	// Instead of anything we would expect, we get lol.
	err = validateIdForName(CertificatesPerDomainPerAccount, "lol")
	test.AssertError(t, err, "instead of anything we would expect, just lol")
	// Valid regId with good domain and a secret third separator.
	err = validateIdForName(CertificatesPerDomainPerAccount, "1234567890:example.com:lol")
	test.AssertError(t, err, "valid regId with good domain and a secret third separator")
	// Valid regId with bad FQDN set.
	err = validateIdForName(CertificatesPerFQDNSetPerAccount, "1234567890:lol..99")
	test.AssertError(t, err, "valid regId with bad FQDN set")
	// Bad regId with good FQDN set.
	err = validateIdForName(CertificatesPerFQDNSetPerAccount, "lol:example.com,example.org")
	test.AssertError(t, err, "bad regId with good FQDN set")
	// Empty regId with good FQDN set.
	err = validateIdForName(CertificatesPerFQDNSetPerAccount, ":example.com,example.org")
	test.AssertError(t, err, "empty regId with good FQDN set")
	// Good regId with empty FQDN set.
	err = validateIdForName(CertificatesPerFQDNSetPerAccount, "1234567890:")
	test.AssertError(t, err, "good regId with empty FQDN set")
	// Empty regId with empty FQDN set, no separator.
	err = validateIdForName(CertificatesPerFQDNSetPerAccount, "")
	test.AssertError(t, err, "empty regId with empty FQDN set, no separator")
	// Instead of anything we would expect, just lol.
	err = validateIdForName(CertificatesPerFQDNSetPerAccount, "lol")
	test.AssertError(t, err, "instead of anything we would expect, just lol")
	// Valid regId with good FQDN set and a secret third separator.
	err = validateIdForName(CertificatesPerFQDNSetPerAccount, "1234567890:example.com,example.org:lol")
	test.AssertError(t, err, "valid regId with good FQDN set and a secret third separator")
}
// Test_loadAndParseOverrideLimits exercises loadAndParseOverrideLimits
// against well-formed and malformed override limit YAML files in testdata/.
// Override keys are of the form 'enum:id', where the id portion itself may
// contain colons (e.g. 'regId:domain' or 'regId:fqdnSetHash').
func Test_loadAndParseOverrideLimits(t *testing.T) {
	newRegistrationsPerIPAddressEnumStr := nameToEnumString(NewRegistrationsPerIPAddress)
	newRegistrationsPerIPv6RangeEnumStr := nameToEnumString(NewRegistrationsPerIPv6Range)
	// Load a single valid override limit with Id formatted as 'enum:RegId'.
	l, err := loadAndParseOverrideLimits("testdata/working_override.yml")
	test.AssertNotError(t, err, "valid single override limit")
	expectKey := newRegistrationsPerIPAddressEnumStr + ":" + "10.0.0.2"
	test.AssertEquals(t, l[expectKey].Burst, int64(40))
	test.AssertEquals(t, l[expectKey].Count, int64(40))
	test.AssertEquals(t, l[expectKey].Period.Duration, time.Second)
	// Load single valid override limit with Id formatted as 'regId:domain'.
	l, err = loadAndParseOverrideLimits("testdata/working_override_regid_domain.yml")
	test.AssertNotError(t, err, "valid single override limit with Id of regId:domain")
	expectKey = nameToEnumString(CertificatesPerDomainPerAccount) + ":" + "12345678:example.com"
	test.AssertEquals(t, l[expectKey].Burst, int64(40))
	test.AssertEquals(t, l[expectKey].Count, int64(40))
	test.AssertEquals(t, l[expectKey].Period.Duration, time.Second)
	// Load multiple valid override limits with 'enum:RegId' Ids.
	l, err = loadAndParseOverrideLimits("testdata/working_overrides.yml")
	expectKey1 := newRegistrationsPerIPAddressEnumStr + ":" + "10.0.0.2"
	test.AssertNotError(t, err, "multiple valid override limits")
	test.AssertEquals(t, l[expectKey1].Burst, int64(40))
	test.AssertEquals(t, l[expectKey1].Count, int64(40))
	test.AssertEquals(t, l[expectKey1].Period.Duration, time.Second)
	expectKey2 := newRegistrationsPerIPv6RangeEnumStr + ":" + "2001:0db8:0000::/48"
	test.AssertEquals(t, l[expectKey2].Burst, int64(50))
	test.AssertEquals(t, l[expectKey2].Count, int64(50))
	test.AssertEquals(t, l[expectKey2].Period.Duration, time.Second*2)
	// Load multiple valid override limits with 'regID:fqdnSet' Ids as follows:
	//   - CertificatesPerFQDNSetPerAccount:12345678:example.com
	//   - CertificatesPerFQDNSetPerAccount:12345678:example.com,example.net
	//   - CertificatesPerFQDNSetPerAccount:12345678:example.com,example.net,example.org
	// The parser replaces the listed domains with their FQDN set hash, so the
	// expected keys are built from core.HashNames of each set.
	firstEntryFQDNSetHash := string(core.HashNames([]string{"example.com"}))
	secondEntryFQDNSetHash := string(core.HashNames([]string{"example.com", "example.net"}))
	thirdEntryFQDNSetHash := string(core.HashNames([]string{"example.com", "example.net", "example.org"}))
	firstEntryKey := nameToEnumString(CertificatesPerFQDNSetPerAccount) + ":" + "12345678:" + firstEntryFQDNSetHash
	secondEntryKey := nameToEnumString(CertificatesPerFQDNSetPerAccount) + ":" + "12345678:" + secondEntryFQDNSetHash
	thirdEntryKey := nameToEnumString(CertificatesPerFQDNSetPerAccount) + ":" + "12345678:" + thirdEntryFQDNSetHash
	l, err = loadAndParseOverrideLimits("testdata/working_overrides_regid_fqdnset.yml")
	test.AssertNotError(t, err, "multiple valid override limits with Id of regId:fqdnSets")
	test.AssertEquals(t, l[firstEntryKey].Burst, int64(40))
	test.AssertEquals(t, l[firstEntryKey].Count, int64(40))
	test.AssertEquals(t, l[firstEntryKey].Period.Duration, time.Second)
	test.AssertEquals(t, l[secondEntryKey].Burst, int64(50))
	test.AssertEquals(t, l[secondEntryKey].Count, int64(50))
	test.AssertEquals(t, l[secondEntryKey].Period.Duration, time.Second*2)
	test.AssertEquals(t, l[thirdEntryKey].Burst, int64(60))
	test.AssertEquals(t, l[thirdEntryKey].Count, int64(60))
	test.AssertEquals(t, l[thirdEntryKey].Period.Duration, time.Second*3)
	// Path is empty string.
	_, err = loadAndParseOverrideLimits("")
	test.AssertError(t, err, "path is empty string")
	test.Assert(t, os.IsNotExist(err), "path is empty string")
	// Path to file which does not exist.
	_, err = loadAndParseOverrideLimits("testdata/file_does_not_exist.yml")
	test.AssertError(t, err, "a file that does not exist ")
	test.Assert(t, os.IsNotExist(err), "test file should not exist")
	// Burst cannot be 0.
	_, err = loadAndParseOverrideLimits("testdata/busted_override_burst_0.yml")
	test.AssertError(t, err, "single override limit with burst=0")
	test.Assert(t, !os.IsNotExist(err), "test file should exist")
	// Id cannot be empty.
	_, err = loadAndParseOverrideLimits("testdata/busted_override_empty_id.yml")
	test.AssertError(t, err, "single override limit with empty id")
	test.Assert(t, !os.IsNotExist(err), "test file should exist")
	// Name cannot be empty.
	_, err = loadAndParseOverrideLimits("testdata/busted_override_empty_name.yml")
	test.AssertError(t, err, "single override limit with empty name")
	test.Assert(t, !os.IsNotExist(err), "test file should exist")
	// Name must be a string representation of a valid Name enumeration.
	_, err = loadAndParseOverrideLimits("testdata/busted_override_invalid_name.yml")
	test.AssertError(t, err, "single override limit with invalid name")
	test.Assert(t, !os.IsNotExist(err), "test file should exist")
	// Multiple entries, second entry has a bad name.
	_, err = loadAndParseOverrideLimits("testdata/busted_overrides_second_entry_bad_name.yml")
	test.AssertError(t, err, "multiple override limits, second entry is bad")
	test.Assert(t, !os.IsNotExist(err), "test file should exist")
	// Multiple entries, third entry has id of "lol", instead of an IPv4 address.
	_, err = loadAndParseOverrideLimits("testdata/busted_overrides_third_entry_bad_id.yml")
	test.AssertError(t, err, "multiple override limits, third entry has bad Id value")
	test.Assert(t, !os.IsNotExist(err), "test file should exist")
}
// Test_loadAndParseDefaultLimits exercises loadAndParseDefaultLimits against
// well-formed and malformed default limit YAML files in testdata/. Default
// keys are the bare 'enum' string for each limit Name.
func Test_loadAndParseDefaultLimits(t *testing.T) {
	perIPAddressKey := nameToEnumString(NewRegistrationsPerIPAddress)
	perIPv6RangeKey := nameToEnumString(NewRegistrationsPerIPv6Range)

	// A single valid default limit should parse cleanly.
	l, err := loadAndParseDefaultLimits("testdata/working_default.yml")
	test.AssertNotError(t, err, "valid single default limit")
	test.AssertEquals(t, l[perIPAddressKey].Burst, int64(20))
	test.AssertEquals(t, l[perIPAddressKey].Count, int64(20))
	test.AssertEquals(t, l[perIPAddressKey].Period.Duration, time.Second)

	// Multiple valid default limits should all be returned.
	l, err = loadAndParseDefaultLimits("testdata/working_defaults.yml")
	test.AssertNotError(t, err, "multiple valid default limits")
	test.AssertEquals(t, l[perIPAddressKey].Burst, int64(20))
	test.AssertEquals(t, l[perIPAddressKey].Count, int64(20))
	test.AssertEquals(t, l[perIPAddressKey].Period.Duration, time.Second)
	test.AssertEquals(t, l[perIPv6RangeKey].Burst, int64(30))
	test.AssertEquals(t, l[perIPv6RangeKey].Count, int64(30))
	test.AssertEquals(t, l[perIPv6RangeKey].Period.Duration, time.Second*2)

	// An empty path should fail with a not-exist error.
	_, err = loadAndParseDefaultLimits("")
	test.AssertError(t, err, "path is empty string")
	test.Assert(t, os.IsNotExist(err), "path is empty string")

	// A path to a nonexistent file should fail with a not-exist error.
	_, err = loadAndParseDefaultLimits("testdata/file_does_not_exist.yml")
	test.AssertError(t, err, "a file that does not exist")
	test.Assert(t, os.IsNotExist(err), "test file should not exist")

	// Burst cannot be 0.
	_, err = loadAndParseDefaultLimits("testdata/busted_default_burst_0.yml")
	test.AssertError(t, err, "single default limit with burst=0")
	test.Assert(t, !os.IsNotExist(err), "test file should exist")

	// Name cannot be empty.
	_, err = loadAndParseDefaultLimits("testdata/busted_default_empty_name.yml")
	test.AssertError(t, err, "single default limit with empty name")
	test.Assert(t, !os.IsNotExist(err), "test file should exist")

	// Name must be a string representation of a valid Name enumeration.
	_, err = loadAndParseDefaultLimits("testdata/busted_default_invalid_name.yml")
	test.AssertError(t, err, "single default limit with invalid name")
	test.Assert(t, !os.IsNotExist(err), "test file should exist")

	// Multiple entries, second entry has a bad name.
	_, err = loadAndParseDefaultLimits("testdata/busted_defaults_second_entry_bad_name.yml")
	test.AssertError(t, err, "multiple default limits, one is bad")
	test.Assert(t, !os.IsNotExist(err), "test file should exist")
}

234
ratelimits/limiter.go Normal file
View File

@ -0,0 +1,234 @@
package ratelimits
import (
"errors"
"fmt"
"time"
"github.com/jmhodges/clock"
)
// ErrInvalidCost indicates that the cost specified was <= 0.
var ErrInvalidCost = errors.New("invalid cost, must be > 0")

// ErrInvalidCostForCheck indicates that the check cost specified was < 0.
var ErrInvalidCostForCheck = errors.New("invalid check cost, must be >= 0")

// ErrInvalidCostOverLimit indicates that the cost specified was > limit.Burst.
var ErrInvalidCostOverLimit = errors.New("invalid cost, must be <= limit.Burst")

// ErrBucketAlreadyFull indicates that the bucket has already reached its
// maximum capacity.
var ErrBucketAlreadyFull = errors.New("bucket already full")
// Limiter provides a high-level interface for rate limiting requests by
// utilizing a leaky bucket-style approach.
type Limiter struct {
	// defaults stores default limits by 'name'.
	defaults limits
	// overrides stores override limits by 'name:id'.
	overrides limits
	// source is used to store buckets. It must be safe for concurrent use.
	source source
	// clk supplies the current time for all bucket calculations; tests
	// inject a fake clock here.
	clk clock.Clock
}
// NewLimiter returns a new *Limiter. The provided source must be safe for
// concurrent use. The defaults and overrides paths are expected to be paths to
// YAML files that contain the default and override limits, respectively. The
// overrides file is optional, all other arguments are required.
func NewLimiter(clk clock.Clock, source source, defaults, overrides string) (*Limiter, error) {
	defaultLimits, err := loadAndParseDefaultLimits(defaults)
	if err != nil {
		return nil, err
	}

	// An empty overrides path means no overrides were configured; use an
	// empty map so lookups simply miss.
	overrideLimits := make(limits)
	if overrides != "" {
		overrideLimits, err = loadAndParseOverrideLimits(overrides)
		if err != nil {
			return nil, err
		}
	}

	return &Limiter{
		defaults:  defaultLimits,
		overrides: overrideLimits,
		source:    source,
		clk:       clk,
	}, nil
}
// Decision is the result of a rate limit Check, Spend, or Refund. It conveys
// whether the operation was (or would be) allowed along with the timing a
// client needs in order to retry or wait for a full bucket reset.
type Decision struct {
	// Allowed is true if the bucket possessed enough capacity to allow the
	// request given the cost.
	Allowed bool
	// Remaining is the number of requests the client is allowed to make before
	// they're rate limited.
	Remaining int64
	// RetryIn is the duration the client MUST wait before they're allowed to
	// make a request.
	RetryIn time.Duration
	// ResetIn is the duration the bucket will take to refill to its maximum
	// capacity, assuming no further requests are made.
	ResetIn time.Duration
	// newTAT indicates the time at which the bucket will be full. It is the
	// theoretical arrival time (TAT) of next request. It must be no more than
	// (burst * (period / count)) in the future at any single point in time.
	newTAT time.Time
}
// Check returns a *Decision indicating whether the bucket for the given limit
// Name and client id has enough capacity to allow the request, given the
// cost. The cost is NOT deducted: the returned *Decision describes the
// hypothetical state of the bucket if the cost WERE spent. The *Decision
// always includes the number of remaining requests in the bucket, the
// required wait time before the client can make another request, and the time
// until the bucket refills to its maximum capacity (resets). If no bucket
// exists for the given limit Name and client id, a new one is created WITHOUT
// the request's cost deducted from its initial capacity.
func (l *Limiter) Check(name Name, id string, cost int64) (*Decision, error) {
	if cost < 0 {
		return nil, ErrInvalidCostForCheck
	}

	limit, err := l.getLimit(name, id)
	if err != nil {
		return nil, err
	}
	if cost > limit.Burst {
		return nil, ErrInvalidCostOverLimit
	}

	tat, err := l.source.Get(bucketKey(name, id))
	if err == nil {
		return maybeSpend(l.clk, limit, tat, cost), nil
	}
	if !errors.Is(err, ErrBucketNotFound) {
		return nil, err
	}

	// First request from this client: create the bucket. Because this is only
	// a check, the bucket is initialized with cost 0 (nothing is deducted).
	d, err := l.initialize(limit, name, id, 0)
	if err != nil {
		return nil, err
	}
	return maybeSpend(l.clk, limit, d.newTAT, cost), nil
}
// Spend returns a *Decision indicating whether enough capacity was available
// to process the request, given the cost, for the specified limit Name and
// client id. If capacity existed, the cost HAS been deducted from the
// bucket's capacity; otherwise nothing was deducted. The returned *Decision
// always includes the number of remaining requests in the bucket, the
// required wait time before the client can make another request, and the time
// until the bucket refills to its maximum capacity (resets). If no bucket
// exists for the given limit Name and client id, a new one is created WITH
// the request's cost deducted from its initial capacity.
func (l *Limiter) Spend(name Name, id string, cost int64) (*Decision, error) {
	if cost <= 0 {
		return nil, ErrInvalidCost
	}

	limit, err := l.getLimit(name, id)
	if err != nil {
		return nil, err
	}
	if cost > limit.Burst {
		return nil, ErrInvalidCostOverLimit
	}

	key := bucketKey(name, id)
	tat, err := l.source.Get(key)
	switch {
	case errors.Is(err, ErrBucketNotFound):
		// First request from this client.
		return l.initialize(limit, name, id, cost)
	case err != nil:
		return nil, err
	}

	d := maybeSpend(l.clk, limit, tat, cost)
	if !d.Allowed {
		// Denied: leave the stored TAT untouched.
		return d, nil
	}
	return d, l.source.Set(key, d.newTAT)
}
// Refund attempts to refund the cost to the bucket identified by limit name
// and client id. The returned *Decision indicates whether the refund was
// successful. On success, the cost was added back to the bucket's capacity;
// if the refund is not possible (the bucket is already full or the cost is
// invalid), nothing is refunded.
//
// Note: a refund can never push the bucket past its maximum capacity. Partial
// refunds are allowed and count as successful: refunding 7 into a bucket of
// capacity 10 that currently holds 5 yields 10 remaining, not 12.
func (l *Limiter) Refund(name Name, id string, cost int64) (*Decision, error) {
	if cost <= 0 {
		return nil, ErrInvalidCost
	}

	limit, err := l.getLimit(name, id)
	if err != nil {
		return nil, err
	}

	key := bucketKey(name, id)
	tat, err := l.source.Get(key)
	if err != nil {
		return nil, err
	}

	d := maybeRefund(l.clk, limit, tat, cost)
	if !d.Allowed {
		// Nothing to refund: the bucket is already at maximum capacity.
		return d, ErrBucketAlreadyFull
	}
	return d, l.source.Set(key, d.newTAT)
}
// Reset resets the specified bucket by deleting it from the underlying
// source; a subsequent Check or Spend will recreate it at full capacity.
func (l *Limiter) Reset(name Name, id string) error {
	key := bucketKey(name, id)
	return l.source.Delete(key)
}
// initialize creates a new bucket, specified by limit name and id, with the
// cost of the request factored into the initial state.
func (l *Limiter) initialize(rl limit, name Name, id string, cost int64) (*Decision, error) {
	now := l.clk.Now()
	d := maybeSpend(l.clk, rl, now, cost)
	if err := l.source.Set(bucketKey(name, id), d.newTAT); err != nil {
		return nil, err
	}
	return d, nil
}
// getLimit returns the limit for the specified name and id. name is required;
// id is optional. When id is non-empty and a matching override exists, the
// override takes precedence; otherwise the default limit for name is
// returned. An error is returned when no limit is configured for name.
func (l *Limiter) getLimit(name Name, id string) (limit, error) {
	if id != "" {
		// An override, if present, wins over the default.
		if override, ok := l.overrides[bucketKey(name, id)]; ok {
			return override, nil
		}
	}
	if def, ok := l.defaults[nameToEnumString(name)]; ok {
		return def, nil
	}
	// NOTE(review): %q on Name relies on Name implementing fmt.Stringer —
	// confirm in names.go.
	return limit{}, fmt.Errorf("limit %q does not exist", name)
}

315
ratelimits/limiter_test.go Normal file
View File

@ -0,0 +1,315 @@
package ratelimits
import (
"testing"
"time"
"github.com/jmhodges/clock"
"github.com/letsencrypt/boulder/test"
)
const (
	// tenZeroZeroOne and tenZeroZeroTwo are client IP addresses used as
	// bucket ids throughout these tests; tenZeroZeroTwo has an override
	// limit configured in testdata/working_override.yml.
	tenZeroZeroOne = "10.0.0.1"
	tenZeroZeroTwo = "10.0.0.2"
)
// newTestLimiter constructs a limiter backed by the in-memory source and a
// fake clock, configured with only the default limits:
//   - 'NewRegistrationsPerIPAddress' burst: 20 count: 20 period: 1s
func newTestLimiter(t *testing.T) (*Limiter, clock.FakeClock) {
	clk := clock.NewFake()
	limiter, err := NewLimiter(clk, newInmem(), "testdata/working_default.yml", "")
	test.AssertNotError(t, err, "should not error")
	return limiter, clk
}
// newTestLimiterWithOverrides constructs a limiter backed by the in-memory
// source and a fake clock, configured with both default and override limits:
//   - 'NewRegistrationsPerIPAddress' burst: 20 count: 20 period: 1s
//   - 'NewRegistrationsPerIPAddress:10.0.0.2' burst: 40 count: 40 period: 1s
func newTestLimiterWithOverrides(t *testing.T) (*Limiter, clock.FakeClock) {
	clk := clock.NewFake()
	limiter, err := NewLimiter(clk, newInmem(), "testdata/working_default.yml", "testdata/working_override.yml")
	test.AssertNotError(t, err, "should not error")
	return limiter, clk
}
// Test_Limiter_initialization_via_Check_and_Spend verifies that both Check
// and Spend lazily create a bucket on first use, and that only Spend actually
// deducts the cost from the newly created bucket.
func Test_Limiter_initialization_via_Check_and_Spend(t *testing.T) {
	l, _ := newTestLimiter(t)
	// Check on an empty bucket should initialize it and return the theoretical
	// next state of that bucket if the cost were spent.
	d, err := l.Check(NewRegistrationsPerIPAddress, tenZeroZeroOne, 1)
	test.AssertNotError(t, err, "should not error")
	test.Assert(t, d.Allowed, "should be allowed")
	test.AssertEquals(t, d.Remaining, int64(19))
	// Verify our ResetIn timing is correct. 1 second == 1000 milliseconds and
	// 1000/20 = 50 milliseconds per request.
	test.AssertEquals(t, d.ResetIn, time.Millisecond*50)
	test.AssertEquals(t, d.RetryIn, time.Duration(0))
	// However, that cost should not be spent yet, a 0 cost check should tell us
	// that we actually have 20 remaining.
	d, err = l.Check(NewRegistrationsPerIPAddress, tenZeroZeroOne, 0)
	test.AssertNotError(t, err, "should not error")
	test.Assert(t, d.Allowed, "should be allowed")
	test.AssertEquals(t, d.Remaining, int64(20))
	test.AssertEquals(t, d.ResetIn, time.Duration(0))
	test.AssertEquals(t, d.RetryIn, time.Duration(0))
	// Reset our bucket.
	err = l.Reset(NewRegistrationsPerIPAddress, tenZeroZeroOne)
	test.AssertNotError(t, err, "should not error")
	// Similar to above, but we'll use Spend() instead of Check() to initialize
	// the bucket. Spend should return the same result as Check.
	d, err = l.Spend(NewRegistrationsPerIPAddress, tenZeroZeroOne, 1)
	test.AssertNotError(t, err, "should not error")
	test.Assert(t, d.Allowed, "should be allowed")
	test.AssertEquals(t, d.Remaining, int64(19))
	// Verify our ResetIn timing is correct. 1 second == 1000 milliseconds and
	// 1000/20 = 50 milliseconds per request.
	test.AssertEquals(t, d.ResetIn, time.Millisecond*50)
	test.AssertEquals(t, d.RetryIn, time.Duration(0))
	// However, unlike Check, Spend's cost WAS deducted: a 0 cost check should
	// tell us that we have 19 remaining.
	d, err = l.Check(NewRegistrationsPerIPAddress, tenZeroZeroOne, 0)
	test.AssertNotError(t, err, "should not error")
	test.Assert(t, d.Allowed, "should be allowed")
	test.AssertEquals(t, d.Remaining, int64(19))
	// Verify our ResetIn is correct. 1 second == 1000 milliseconds and
	// 1000/20 = 50 milliseconds per request.
	test.AssertEquals(t, d.ResetIn, time.Millisecond*50)
	test.AssertEquals(t, d.RetryIn, time.Duration(0))
}
// Test_Limiter_Refund_and_Spend_cost_err verifies that Spend and Refund both
// reject a zero or negative cost with ErrInvalidCost.
func Test_Limiter_Refund_and_Spend_cost_err(t *testing.T) {
	l, _ := newTestLimiter(t)

	// Neither operation touches the bucket on an invalid cost, so order does
	// not matter here.
	for _, badCost := range []int64{0, -1} {
		_, err := l.Spend(NewRegistrationsPerIPAddress, tenZeroZeroOne, badCost)
		test.AssertErrorIs(t, err, ErrInvalidCost)

		_, err = l.Refund(NewRegistrationsPerIPAddress, tenZeroZeroOne, badCost)
		test.AssertErrorIs(t, err, ErrInvalidCost)
	}
}
// Test_Limiter_with_bad_limits_path ensures NewLimiter fails when either the
// defaults or the overrides path does not exist.
func Test_Limiter_with_bad_limits_path(t *testing.T) {
	// Nonexistent defaults file.
	_, err := NewLimiter(clock.NewFake(), newInmem(), "testdata/does-not-exist.yml", "")
	test.AssertError(t, err, "should error")

	// Valid defaults file paired with a nonexistent overrides file. The
	// defaults path must be loadable here (the previous "testdata/defaults.yml"
	// does not exist in testdata/, so the defaults load failed first and the
	// overrides code path was never exercised).
	_, err = NewLimiter(clock.NewFake(), newInmem(), "testdata/working_default.yml", "testdata/does-not-exist.yml")
	test.AssertError(t, err, "should error")
}
// Test_Limiter_Check_bad_cost verifies that Check rejects a negative cost
// with ErrInvalidCostForCheck.
func Test_Limiter_Check_bad_cost(t *testing.T) {
	limiter, _ := newTestLimiter(t)
	_, err := limiter.Check(NewRegistrationsPerIPAddress, tenZeroZeroOne, -1)
	test.AssertErrorIs(t, err, ErrInvalidCostForCheck)
}
// Test_Limiter_Check_limit_no_exist verifies that Check fails for a limit
// Name with no configured default.
func Test_Limiter_Check_limit_no_exist(t *testing.T) {
	limiter, _ := newTestLimiter(t)
	_, err := limiter.Check(Name(9999), tenZeroZeroOne, 1)
	test.AssertError(t, err, "should error")
}
// Test_Limiter_getLimit_no_exist verifies that getLimit fails for a limit
// Name with no configured default.
func Test_Limiter_getLimit_no_exist(t *testing.T) {
	limiter, _ := newTestLimiter(t)
	_, err := limiter.getLimit(Name(9999), "")
	test.AssertError(t, err, "should error")
}
// Test_Limiter_with_defaults walks a bucket governed only by the default
// limit (burst 20, count 20, period 1s) through drain, retry-after-wait, and
// full-reset cycles.
func Test_Limiter_with_defaults(t *testing.T) {
	l, clk := newTestLimiter(t)
	// Attempt to spend 21 requests (a cost > the limit burst capacity), this
	// should fail with a specific error.
	_, err := l.Spend(NewRegistrationsPerIPAddress, tenZeroZeroOne, 21)
	test.AssertErrorIs(t, err, ErrInvalidCostOverLimit)
	// Attempt to spend all 20 requests, this should succeed.
	d, err := l.Spend(NewRegistrationsPerIPAddress, tenZeroZeroOne, 20)
	test.AssertNotError(t, err, "should not error")
	test.Assert(t, d.Allowed, "should be allowed")
	test.AssertEquals(t, d.Remaining, int64(0))
	test.AssertEquals(t, d.ResetIn, time.Second)
	// Attempting to spend 1 more, this should fail.
	d, err = l.Spend(NewRegistrationsPerIPAddress, tenZeroZeroOne, 1)
	test.AssertNotError(t, err, "should not error")
	test.Assert(t, !d.Allowed, "should not be allowed")
	test.AssertEquals(t, d.Remaining, int64(0))
	test.AssertEquals(t, d.ResetIn, time.Second)
	// Verify our RetryIn is correct. 1 second == 1000 milliseconds and
	// 1000/20 = 50 milliseconds per request.
	test.AssertEquals(t, d.RetryIn, time.Millisecond*50)
	// Wait 50 milliseconds and try again.
	clk.Add(d.RetryIn)
	// We should be allowed to spend 1 more request.
	d, err = l.Spend(NewRegistrationsPerIPAddress, tenZeroZeroOne, 1)
	test.AssertNotError(t, err, "should not error")
	test.Assert(t, d.Allowed, "should be allowed")
	test.AssertEquals(t, d.Remaining, int64(0))
	test.AssertEquals(t, d.ResetIn, time.Second)
	// Wait 1 second for a full bucket reset.
	clk.Add(d.ResetIn)
	// Quickly spend 20 requests in a row.
	for i := 0; i < 20; i++ {
		d, err = l.Spend(NewRegistrationsPerIPAddress, tenZeroZeroOne, 1)
		test.AssertNotError(t, err, "should not error")
		test.Assert(t, d.Allowed, "should be allowed")
		test.AssertEquals(t, d.Remaining, int64(19-i))
	}
	// Attempting to spend 1 more, this should fail.
	d, err = l.Spend(NewRegistrationsPerIPAddress, tenZeroZeroOne, 1)
	test.AssertNotError(t, err, "should not error")
	test.Assert(t, !d.Allowed, "should not be allowed")
	test.AssertEquals(t, d.Remaining, int64(0))
	test.AssertEquals(t, d.ResetIn, time.Second)
}
// Test_Limiter_with_limit_overrides walks a bucket governed by a
// per-subscriber override (burst 40, count 40, period 1s for 10.0.0.2)
// through the same drain, retry, and reset cycles as the defaults test.
func Test_Limiter_with_limit_overrides(t *testing.T) {
	l, clk := newTestLimiterWithOverrides(t)
	// Attempt to check a spend of 41 requests (a cost > the limit burst
	// capacity), this should fail with a specific error.
	_, err := l.Check(NewRegistrationsPerIPAddress, tenZeroZeroTwo, 41)
	test.AssertErrorIs(t, err, ErrInvalidCostOverLimit)
	// Attempt to spend 41 requests (a cost > the limit burst capacity), this
	// should fail with a specific error.
	_, err = l.Spend(NewRegistrationsPerIPAddress, tenZeroZeroTwo, 41)
	test.AssertErrorIs(t, err, ErrInvalidCostOverLimit)
	// Attempt to spend all 40 requests, this should succeed.
	d, err := l.Spend(NewRegistrationsPerIPAddress, tenZeroZeroTwo, 40)
	test.AssertNotError(t, err, "should not error")
	test.Assert(t, d.Allowed, "should be allowed")
	// Attempting to spend 1 more, this should fail.
	d, err = l.Spend(NewRegistrationsPerIPAddress, tenZeroZeroTwo, 1)
	test.AssertNotError(t, err, "should not error")
	test.Assert(t, !d.Allowed, "should not be allowed")
	test.AssertEquals(t, d.Remaining, int64(0))
	test.AssertEquals(t, d.ResetIn, time.Second)
	// Verify our RetryIn is correct. 1 second == 1000 milliseconds and
	// 1000/40 = 25 milliseconds per request.
	test.AssertEquals(t, d.RetryIn, time.Millisecond*25)
	// Wait 25 milliseconds and try again.
	clk.Add(d.RetryIn)
	// We should be allowed to spend 1 more request.
	d, err = l.Spend(NewRegistrationsPerIPAddress, tenZeroZeroTwo, 1)
	test.AssertNotError(t, err, "should not error")
	test.Assert(t, d.Allowed, "should be allowed")
	test.AssertEquals(t, d.Remaining, int64(0))
	test.AssertEquals(t, d.ResetIn, time.Second)
	// Wait 1 second for a full bucket reset.
	clk.Add(d.ResetIn)
	// Quickly spend 40 requests in a row.
	for i := 0; i < 40; i++ {
		d, err = l.Spend(NewRegistrationsPerIPAddress, tenZeroZeroTwo, 1)
		test.AssertNotError(t, err, "should not error")
		test.Assert(t, d.Allowed, "should be allowed")
		test.AssertEquals(t, d.Remaining, int64(39-i))
	}
	// Attempting to spend 1 more, this should fail.
	d, err = l.Spend(NewRegistrationsPerIPAddress, tenZeroZeroTwo, 1)
	test.AssertNotError(t, err, "should not error")
	test.Assert(t, !d.Allowed, "should not be allowed")
	test.AssertEquals(t, d.Remaining, int64(0))
	test.AssertEquals(t, d.ResetIn, time.Second)
}
// Test_Limiter_with_new_clients verifies that draining one client's bucket
// has no effect on the fresh bucket created for a different client.
func Test_Limiter_with_new_clients(t *testing.T) {
	l, _ := newTestLimiter(t)

	// Drain all 20 requests for the first client.
	d, err := l.Spend(NewRegistrationsPerIPAddress, tenZeroZeroOne, 20)
	test.AssertNotError(t, err, "should not error")
	test.Assert(t, d.Allowed, "should be allowed")
	test.AssertEquals(t, d.Remaining, int64(0))
	test.AssertEquals(t, d.ResetIn, time.Second)

	// A different client gets a brand-new bucket: spend 1 and confirm the
	// remaining capacity reflects only that spend.
	d, err = l.Spend(NewRegistrationsPerIPAddress, "10.0.0.100", 1)
	test.AssertNotError(t, err, "should not error")
	test.Assert(t, d.Allowed, "should be allowed")
	test.AssertEquals(t, d.Remaining, int64(19))
	test.AssertEquals(t, d.RetryIn, time.Duration(0))
	// 1 second == 1000 milliseconds and 1000/20 = 50 milliseconds per request.
	test.AssertEquals(t, d.ResetIn, time.Millisecond*50)
}
// Test_Limiter_Refund_and_Reset verifies that Refund restores spent capacity,
// that Reset returns a bucket to full capacity, and that refunding an already
// full bucket fails with ErrBucketAlreadyFull.
func Test_Limiter_Refund_and_Reset(t *testing.T) {
	l, clk := newTestLimiter(t)
	// Attempt to spend all 20 requests, this should succeed.
	d, err := l.Spend(NewRegistrationsPerIPAddress, tenZeroZeroOne, 20)
	test.AssertNotError(t, err, "should not error")
	test.Assert(t, d.Allowed, "should be allowed")
	test.AssertEquals(t, d.Remaining, int64(0))
	test.AssertEquals(t, d.ResetIn, time.Second)
	// Refund 10 requests.
	d, err = l.Refund(NewRegistrationsPerIPAddress, tenZeroZeroOne, 10)
	test.AssertNotError(t, err, "should not error")
	test.AssertEquals(t, d.Remaining, int64(10))
	// Spend 10 requests, this should succeed.
	d, err = l.Spend(NewRegistrationsPerIPAddress, tenZeroZeroOne, 10)
	test.AssertNotError(t, err, "should not error")
	test.Assert(t, d.Allowed, "should be allowed")
	test.AssertEquals(t, d.Remaining, int64(0))
	test.AssertEquals(t, d.ResetIn, time.Second)
	err = l.Reset(NewRegistrationsPerIPAddress, tenZeroZeroOne)
	test.AssertNotError(t, err, "should not error")
	// Attempt to spend 20 more requests, this should succeed.
	d, err = l.Spend(NewRegistrationsPerIPAddress, tenZeroZeroOne, 20)
	test.AssertNotError(t, err, "should not error")
	test.Assert(t, d.Allowed, "should be allowed")
	test.AssertEquals(t, d.Remaining, int64(0))
	test.AssertEquals(t, d.ResetIn, time.Second)
	// Reset to full.
	clk.Add(d.ResetIn)
	// Refund 1 request while the bucket is already full, this should fail.
	d, err = l.Refund(NewRegistrationsPerIPAddress, tenZeroZeroOne, 1)
	test.AssertErrorIs(t, err, ErrBucketAlreadyFull)
	test.AssertEquals(t, d.Remaining, int64(20))
}
// Test_Limiter_Check_Spend_parity verifies that a Check and a Spend of the
// same cost against identical fresh buckets report the same remaining
// capacity.
func Test_Limiter_Check_Spend_parity(t *testing.T) {
	checkLimiter, _ := newTestLimiter(t)
	spendLimiter, _ := newTestLimiter(t)

	checkDecision, err := checkLimiter.Check(NewRegistrationsPerIPAddress, tenZeroZeroOne, 1)
	test.AssertNotError(t, err, "should not error")

	spendDecision, err := spendLimiter.Spend(NewRegistrationsPerIPAddress, tenZeroZeroOne, 1)
	test.AssertNotError(t, err, "should not error")

	test.AssertDeepEquals(t, checkDecision.Remaining, spendDecision.Remaining)
}

202
ratelimits/names.go Normal file
View File

@ -0,0 +1,202 @@
package ratelimits
import (
"fmt"
"net"
"strconv"
"strings"
"github.com/letsencrypt/boulder/policy"
)
// Name is an enumeration of all rate limit names. It is used to intern rate
// limit names as strings and to provide a type-safe way to refer to rate
// limits.
//
// IMPORTANT: If you add a new limit Name, you MUST add it to the 'nameToString'
// mapping and idValidForName function below.
type Name int

const (
	// Unknown is the zero value of Name and is used to indicate an unknown
	// limit name.
	Unknown Name = iota
	// NewRegistrationsPerIPAddress uses bucket key 'enum:ipAddress'.
	NewRegistrationsPerIPAddress
	// NewRegistrationsPerIPv6Range uses bucket key 'enum:ipv6rangeCIDR'. The
	// address range must be a /48.
	NewRegistrationsPerIPv6Range
	// NewOrdersPerAccount uses bucket key 'enum:regId'.
	NewOrdersPerAccount
	// FailedAuthorizationsPerAccount uses bucket key 'enum:regId', where regId
	// is the registration id of the account.
	FailedAuthorizationsPerAccount
	// CertificatesPerDomainPerAccount uses bucket key 'enum:regId:domain',
	// where domain is a name in a certificate issued to the account matching
	// regId.
	CertificatesPerDomainPerAccount
	// CertificatesPerFQDNSetPerAccount uses bucket key 'enum:regId:fqdnSet',
	// where fqdnSet is a set of names in a certificate issued to the account
	// matching regId.
	CertificatesPerFQDNSetPerAccount
)
// nameToString is a map of Name values to string names. It is the
// authoritative list of defined limits: stringToName and limitNames below are
// both derived from it, so every new Name must be added here.
var nameToString = map[Name]string{
	Unknown:                          "Unknown",
	NewRegistrationsPerIPAddress:     "NewRegistrationsPerIPAddress",
	NewRegistrationsPerIPv6Range:     "NewRegistrationsPerIPv6Range",
	NewOrdersPerAccount:              "NewOrdersPerAccount",
	FailedAuthorizationsPerAccount:   "FailedAuthorizationsPerAccount",
	CertificatesPerDomainPerAccount:  "CertificatesPerDomainPerAccount",
	CertificatesPerFQDNSetPerAccount: "CertificatesPerFQDNSetPerAccount",
}
// validIPAddress validates that the provided string parses as an IPv4 or
// IPv6 address.
func validIPAddress(id string) error {
	if net.ParseIP(id) == nil {
		return fmt.Errorf("invalid IP address, %q must be an IP address", id)
	}
	return nil
}
// validIPv6RangeCIDR validates that the provided string is an IPv6 CIDR
// range with exactly a /48 mask, the granularity at which IPv6 registrations
// are limited.
func validIPv6RangeCIDR(id string) error {
	_, network, parseErr := net.ParseCIDR(id)
	if parseErr != nil {
		return fmt.Errorf(
			"invalid CIDR, %q must be an IPv6 CIDR range", id)
	}
	if prefixLen, _ := network.Mask.Size(); prefixLen != 48 {
		// An IPv4 CIDR also lands here: its prefix length tops out at /32,
		// so it can never equal 48.
		return fmt.Errorf(
			"invalid CIDR, %q must be /48", id)
	}
	return nil
}
// validateRegId validates that the provided string is a base-10 unsigned
// integer, the form of an ACME registration (account) Id.
func validateRegId(id string) error {
	if _, err := strconv.ParseUint(id, 10, 64); err != nil {
		return fmt.Errorf("invalid regId, %q must be an ACME registration Id", id)
	}
	return nil
}
// validateRegIdDomain validates that the provided string is formatted
// 'regId:domain': an ACME registration Id, a colon, then a single valid
// domain name.
func validateRegIdDomain(id string) error {
	regId, domain, found := strings.Cut(id, ":")
	if !found {
		return fmt.Errorf(
			"invalid regId:domain, %q must be formatted 'regId:domain'", id)
	}
	if err := validateRegId(regId); err != nil {
		return fmt.Errorf(
			"invalid regId, %q must be formatted 'regId:domain'", id)
	}
	if err := policy.ValidDomain(domain); err != nil {
		return fmt.Errorf(
			"invalid domain, %q must be formatted 'regId:domain'", id)
	}
	return nil
}
// validateRegIdFQDNSet validates that the provided string is formatted
// 'regId:fqdnSet': an ACME registration Id, a colon, then one or more
// comma-separated domain names.
func validateRegIdFQDNSet(id string) error {
	regId, fqdnSet, found := strings.Cut(id, ":")
	if !found {
		return fmt.Errorf(
			"invalid regId:fqdnSet, %q must be formatted 'regId:fqdnSet'", id)
	}
	if err := validateRegId(regId); err != nil {
		return fmt.Errorf(
			"invalid regId, %q must be formatted 'regId:fqdnSet'", id)
	}
	domains := strings.Split(fqdnSet, ",")
	if len(domains) == 0 {
		// Defensive: strings.Split with a non-empty separator always returns
		// at least one element, so this branch should be unreachable.
		return fmt.Errorf(
			"invalid fqdnSet, %q must be formatted 'regId:fqdnSet'", id)
	}
	for _, domain := range domains {
		if err := policy.ValidDomain(domain); err != nil {
			return fmt.Errorf(
				"invalid domain, %q must be formatted 'regId:fqdnSet'", id)
		}
	}
	return nil
}
// validateIdForName validates the provided subscriber id against the format
// expected by the given rate limit Name. See the Name constants above for
// the bucket key schema each limit uses.
func validateIdForName(name Name, id string) error {
	switch name {
	case NewRegistrationsPerIPAddress:
		// 'enum:ipAddress'
		return validIPAddress(id)
	case NewRegistrationsPerIPv6Range:
		// 'enum:ipv6rangeCIDR'
		return validIPv6RangeCIDR(id)
	case NewOrdersPerAccount, FailedAuthorizationsPerAccount:
		// 'enum:regId'
		return validateRegId(id)
	case CertificatesPerDomainPerAccount:
		// 'enum:regId:domain'
		return validateRegIdDomain(id)
	case CertificatesPerFQDNSetPerAccount:
		// 'enum:regId:fqdnSet'
		return validateRegIdFQDNSet(id)
	case Unknown:
		fallthrough
	default:
		// This should never happen: every defined Name has a case above. Use
		// %d, not %q: Name is an integer type with no String method, so %q
		// would render it as a quoted rune (e.g. '\x00') instead of a number.
		return fmt.Errorf("unknown limit enum %d", name)
	}
}
// stringToName is the inverse of nameToString: it maps each string name back
// to its Name value. It is built once at package initialization.
var stringToName = func() map[string]Name {
	inverted := make(map[string]Name, len(nameToString))
	for name, str := range nameToString {
		inverted[str] = name
	}
	return inverted
}()
// limitNames is a slice of all rate limit names, in no particular order
// (map iteration order is not deterministic).
//
// Note: the slice is made with length 0 and capacity len(nameToString). The
// previous version used make([]string, len(nameToString)), which pre-filled
// the slice with empty strings that the subsequent appends were added after,
// doubling its length and leaving bogus "" entries.
var limitNames = func() []string {
	names := make([]string, 0, len(nameToString))
	for _, v := range nameToString {
		names = append(names, v)
	}
	return names
}()
// nameToEnumString renders a Name as the decimal string form of its
// underlying integer value, used as the enum portion of bucket keys.
func nameToEnumString(name Name) string {
	return strconv.FormatInt(int64(name), 10)
}

// bucketKey joins a limit's enum string and a subscriber id into the
// 'enum:id' key under which that subscriber's bucket is stored.
func bucketKey(name Name, id string) string {
	return nameToEnumString(name) + ":" + id
}

57
ratelimits/source.go Normal file
View File

@ -0,0 +1,57 @@
package ratelimits
import (
"fmt"
"sync"
"time"
)
// ErrBucketNotFound indicates that the bucket was not found.
var ErrBucketNotFound = fmt.Errorf("bucket not found")
// source is an interface for creating and modifying TATs.
type source interface {
// Set stores the TAT at the specified bucketKey ('name:id').
Set(bucketKey string, tat time.Time) error
// Get retrieves the TAT at the specified bucketKey ('name:id').
Get(bucketKey string) (time.Time, error)
// Delete deletes the TAT at the specified bucketKey ('name:id').
Delete(bucketKey string) error
}
// inmem is an in-memory implementation of the source interface used for
// testing.
type inmem struct {
sync.RWMutex
m map[string]time.Time
}
func newInmem() *inmem {
return &inmem{m: make(map[string]time.Time)}
}
func (in *inmem) Set(bucketKey string, tat time.Time) error {
in.Lock()
defer in.Unlock()
in.m[bucketKey] = tat
return nil
}
func (in *inmem) Get(bucketKey string) (time.Time, error) {
in.RLock()
defer in.RUnlock()
tat, ok := in.m[bucketKey]
if !ok {
return time.Time{}, ErrBucketNotFound
}
return tat, nil
}
func (in *inmem) Delete(bucketKey string) error {
in.Lock()
defer in.Unlock()
delete(in.m, bucketKey)
return nil
}

View File

@ -0,0 +1,4 @@
NewRegistrationsPerIPAddress:
burst: 0
count: 20
period: 1s

View File

@ -0,0 +1,4 @@
"":
burst: 20
count: 20
period: 1s

View File

@ -0,0 +1,4 @@
UsageRequestsPerIPv10Address:
burst: 20
count: 20
period: 1s

View File

@ -0,0 +1,8 @@
NewRegistrationsPerIPAddress:
burst: 20
count: 20
period: 1s
UsageRequestsPerIPv10Address:
burst: 20
count: 20
period: 1s

View File

@ -0,0 +1,4 @@
NewRegistrationsPerIPAddress:10.0.0.2:
burst: 0
count: 40
period: 1s

View File

@ -0,0 +1,4 @@
"UsageRequestsPerIPv10Address:":
burst: 40
count: 40
period: 1s

View File

@ -0,0 +1,4 @@
":10.0.0.2":
burst: 40
count: 40
period: 1s

View File

@ -0,0 +1,4 @@
UsageRequestsPerIPv10Address:10.0.0.2:
burst: 40
count: 40
period: 1s

View File

@ -0,0 +1,8 @@
NewRegistrationsPerIPAddress:10.0.0.2:
burst: 40
count: 40
period: 1s
UsageRequestsPerIPv10Address:10.0.0.5:
burst: 40
count: 40
period: 1s

View File

@ -0,0 +1,12 @@
NewRegistrationsPerIPAddress:10.0.0.2:
burst: 40
count: 40
period: 1s
NewRegistrationsPerIPAddress:10.0.0.5:
burst: 40
count: 40
period: 1s
NewRegistrationsPerIPAddress:lol:
burst: 40
count: 40
period: 1s

View File

@ -0,0 +1,4 @@
NewRegistrationsPerIPAddress:
burst: 20
count: 20
period: 1s

View File

@ -0,0 +1,8 @@
NewRegistrationsPerIPAddress:
burst: 20
count: 20
period: 1s
NewRegistrationsPerIPv6Range:
burst: 30
count: 30
period: 2s

View File

@ -0,0 +1,4 @@
NewRegistrationsPerIPAddress:10.0.0.2:
burst: 40
count: 40
period: 1s

View File

@ -0,0 +1,4 @@
CertificatesPerDomainPerAccount:12345678:example.com:
burst: 40
count: 40
period: 1s

View File

@ -0,0 +1,8 @@
NewRegistrationsPerIPAddress:10.0.0.2:
burst: 40
count: 40
period: 1s
NewRegistrationsPerIPv6Range:2001:0db8:0000::/48:
burst: 50
count: 50
period: 2s

View File

@ -0,0 +1,12 @@
CertificatesPerFQDNSetPerAccount:12345678:example.com:
burst: 40
count: 40
period: 1s
CertificatesPerFQDNSetPerAccount:12345678:example.com,example.net:
burst: 50
count: 50
period: 2s
CertificatesPerFQDNSetPerAccount:12345678:example.com,example.net,example.org:
burst: 60
count: 60
period: 3s

View File

@ -12,7 +12,6 @@ import (
"net"
"net/url"
"strconv"
"strings"
"time"
"golang.org/x/exp/slices"
@ -873,14 +872,6 @@ type crlEntryModel struct {
RevokedDate time.Time `db:"revokedDate"`
}
// HashNames returns a hash of the names requested. This is intended for use
// when interacting with the orderFqdnSets table.
func HashNames(names []string) []byte {
names = core.UniqueLowerNames(names)
hash := sha256.Sum256([]byte(strings.Join(names, ",")))
return hash[:]
}
// orderFQDNSet contains the SHA256 hash of the lowercased, comma joined names
// from a new-order request, along with the corresponding orderID, the
// registration ID, and the order expiry. This is used to find
@ -895,7 +886,7 @@ type orderFQDNSet struct {
func addFQDNSet(ctx context.Context, db db.Inserter, names []string, serial string, issued time.Time, expires time.Time) error {
return db.Insert(ctx, &core.FQDNSet{
SetHash: HashNames(names),
SetHash: core.HashNames(names),
Serial: serial,
Issued: issued,
Expires: expires,
@ -914,7 +905,7 @@ func addOrderFQDNSet(
regID int64,
expires time.Time) error {
return db.Insert(ctx, &orderFQDNSet{
SetHash: HashNames(names),
SetHash: core.HashNames(names),
OrderID: orderID,
RegistrationID: regID,
Expires: expires,

View File

@ -2893,33 +2893,6 @@ func TestBlockedKeyRevokedBy(t *testing.T) {
test.AssertNotError(t, err, "AddBlockedKey failed")
}
func TestHashNames(t *testing.T) {
// Test that it is deterministic
h1 := HashNames([]string{"a"})
h2 := HashNames([]string{"a"})
test.AssertByteEquals(t, h1, h2)
// Test that it differentiates
h1 = HashNames([]string{"a"})
h2 = HashNames([]string{"b"})
test.Assert(t, !bytes.Equal(h1, h2), "Should have been different")
// Test that it is not subject to ordering
h1 = HashNames([]string{"a", "b"})
h2 = HashNames([]string{"b", "a"})
test.AssertByteEquals(t, h1, h2)
// Test that it is not subject to case
h1 = HashNames([]string{"a", "b"})
h2 = HashNames([]string{"A", "B"})
test.AssertByteEquals(t, h1, h2)
// Test that it is not subject to duplication
h1 = HashNames([]string{"a", "a"})
h2 = HashNames([]string{"a"})
test.AssertByteEquals(t, h1, h2)
}
func TestIncidentsForSerial(t *testing.T) {
sa, _, cleanUp := initSA(t)
defer cleanUp()

View File

@ -516,7 +516,7 @@ func (ssa *SQLStorageAuthorityRO) CountFQDNSets(ctx context.Context, req *sapb.C
`SELECT COUNT(*) FROM fqdnSets
WHERE setHash = ?
AND issued > ?`,
HashNames(req.Domains),
core.HashNames(req.Domains),
ssa.clk.Now().Add(-time.Duration(req.Window)),
)
return &sapb.Count{Count: count}, err
@ -544,7 +544,7 @@ func (ssa *SQLStorageAuthorityRO) FQDNSetTimestampsForWindow(ctx context.Context
WHERE setHash = ?
AND issued > ?
ORDER BY issued DESC`,
HashNames(req.Domains),
core.HashNames(req.Domains),
ssa.clk.Now().Add(-time.Duration(req.Window)),
)
if err != nil {
@ -586,7 +586,7 @@ type oneSelectorFunc func(ctx context.Context, holder interface{}, query string,
// checkFQDNSetExists uses the given oneSelectorFunc to check whether an fqdnSet
// for the given names exists.
func (ssa *SQLStorageAuthorityRO) checkFQDNSetExists(ctx context.Context, selector oneSelectorFunc, names []string) (bool, error) {
namehash := HashNames(names)
namehash := core.HashNames(names)
var exists bool
err := selector(
ctx,
@ -761,7 +761,7 @@ func (ssa *SQLStorageAuthorityRO) GetOrderForNames(ctx context.Context, req *sap
}
// Hash the names requested for lookup in the orderFqdnSets table
fqdnHash := HashNames(req.Names)
fqdnHash := core.HashNames(req.Names)
// Find a possibly-suitable order. We don't include the account ID or order
// status in this query because there's no index that includes those, so