Initial implementation of key-value rate limits (#6947)
This design seeks to reduce read-pressure on our DB by moving rate limit tabulation to a key-value datastore. This PR provides the following: - (README.md) a short guide to the schemas, formats, and concepts introduced in this PR - (source.go) an interface for storing, retrieving, and resetting a subscriber bucket - (name.go) an enumeration of all defined rate limits - (limit.go) a schema for defining default limits and per-subscriber overrides - (limiter.go) a high-level API for interacting with key-value rate limits - (gcra.go) an implementation of the Generic Cell Rate Algorithm, a leaky bucket-style scheduling algorithm, used to calculate the present or future capacity of a subscriber bucket using spend and refund operations Note: the included source implementation is test-only and currently accomplished using a simple in-memory map protected by a mutex, implementations using Redis and potentially other data stores will follow. Part of #5545
This commit is contained in:
parent
e955494955
commit
055f620c4b
|
@ -304,7 +304,7 @@ func (m *mailer) updateLastNagTimestampsChunk(ctx context.Context, certs []*x509
|
|||
}
|
||||
|
||||
func (m *mailer) certIsRenewed(ctx context.Context, names []string, issued time.Time) (bool, error) {
|
||||
namehash := sa.HashNames(names)
|
||||
namehash := core.HashNames(names)
|
||||
|
||||
var present bool
|
||||
err := m.dbMap.SelectOne(
|
||||
|
|
|
@ -353,7 +353,7 @@ func TestNoContactCertIsRenewed(t *testing.T) {
|
|||
setupDBMap, err := sa.DBMapForTest(vars.DBConnSAFullPerms)
|
||||
test.AssertNotError(t, err, "setting up DB")
|
||||
err = setupDBMap.Insert(ctx, &core.FQDNSet{
|
||||
SetHash: sa.HashNames(names),
|
||||
SetHash: core.HashNames(names),
|
||||
Serial: core.SerialToString(serial2),
|
||||
Issued: testCtx.fc.Now().Add(time.Hour),
|
||||
Expires: expires.Add(time.Hour),
|
||||
|
@ -580,13 +580,13 @@ func addExpiringCerts(t *testing.T, ctx *testCtx) []certDERWithRegID {
|
|||
test.AssertNotError(t, err, "creating cert D")
|
||||
|
||||
fqdnStatusD := &core.FQDNSet{
|
||||
SetHash: sa.HashNames(certDNames),
|
||||
SetHash: core.HashNames(certDNames),
|
||||
Serial: serial4String,
|
||||
Issued: ctx.fc.Now().AddDate(0, 0, -87),
|
||||
Expires: ctx.fc.Now().AddDate(0, 0, 3),
|
||||
}
|
||||
fqdnStatusDRenewed := &core.FQDNSet{
|
||||
SetHash: sa.HashNames(certDNames),
|
||||
SetHash: core.HashNames(certDNames),
|
||||
Serial: serial5String,
|
||||
Issued: ctx.fc.Now().AddDate(0, 0, -3),
|
||||
Expires: ctx.fc.Now().AddDate(0, 0, 87),
|
||||
|
@ -747,7 +747,7 @@ func TestCertIsRenewed(t *testing.T) {
|
|||
t.Fatal(err)
|
||||
}
|
||||
fqdnStatus := &core.FQDNSet{
|
||||
SetHash: sa.HashNames(testData.DNS),
|
||||
SetHash: core.HashNames(testData.DNS),
|
||||
Serial: testData.stringSerial,
|
||||
Issued: testData.NotBefore,
|
||||
Expires: testData.NotAfter,
|
||||
|
|
|
@ -242,6 +242,14 @@ func UniqueLowerNames(names []string) (unique []string) {
|
|||
return
|
||||
}
|
||||
|
||||
// HashNames returns a hash of the names requested. This is intended for use
|
||||
// when interacting with the orderFqdnSets table and rate limiting.
|
||||
func HashNames(names []string) []byte {
|
||||
names = UniqueLowerNames(names)
|
||||
hash := sha256.Sum256([]byte(strings.Join(names, ",")))
|
||||
return hash[:]
|
||||
}
|
||||
|
||||
// LoadCert loads a PEM certificate specified by filename or returns an error
|
||||
func LoadCert(filename string) (*x509.Certificate, error) {
|
||||
certPEM, err := os.ReadFile(filename)
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package core
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math"
|
||||
|
@ -206,3 +207,30 @@ func TestRetryBackoff(t *testing.T) {
|
|||
assertBetween(float64(backoff), float64(expected)*0.8, float64(expected)*1.2)
|
||||
|
||||
}
|
||||
|
||||
func TestHashNames(t *testing.T) {
|
||||
// Test that it is deterministic
|
||||
h1 := HashNames([]string{"a"})
|
||||
h2 := HashNames([]string{"a"})
|
||||
test.AssertByteEquals(t, h1, h2)
|
||||
|
||||
// Test that it differentiates
|
||||
h1 = HashNames([]string{"a"})
|
||||
h2 = HashNames([]string{"b"})
|
||||
test.Assert(t, !bytes.Equal(h1, h2), "Should have been different")
|
||||
|
||||
// Test that it is not subject to ordering
|
||||
h1 = HashNames([]string{"a", "b"})
|
||||
h2 = HashNames([]string{"b", "a"})
|
||||
test.AssertByteEquals(t, h1, h2)
|
||||
|
||||
// Test that it is not subject to case
|
||||
h1 = HashNames([]string{"a", "b"})
|
||||
h2 = HashNames([]string{"A", "B"})
|
||||
test.AssertByteEquals(t, h1, h2)
|
||||
|
||||
// Test that it is not subject to duplication
|
||||
h1 = HashNames([]string{"a", "a"})
|
||||
h2 = HashNames([]string{"a"})
|
||||
test.AssertByteEquals(t, h1, h2)
|
||||
}
|
||||
|
|
|
@ -196,7 +196,7 @@ var (
|
|||
errWildcardNotSupported = berrors.MalformedError("Wildcard domain names are not supported")
|
||||
)
|
||||
|
||||
// validDomain checks that a domain isn't:
|
||||
// ValidDomain checks that a domain isn't:
|
||||
//
|
||||
// * empty
|
||||
// * prefixed with the wildcard label `*.`
|
||||
|
@ -210,7 +210,7 @@ var (
|
|||
// * exactly equal to an IANA registered TLD
|
||||
//
|
||||
// It does _not_ check that the domain isn't on any PA blocked lists.
|
||||
func validDomain(domain string) error {
|
||||
func ValidDomain(domain string) error {
|
||||
if domain == "" {
|
||||
return errEmptyName
|
||||
}
|
||||
|
@ -323,7 +323,7 @@ func ValidEmail(address string) error {
|
|||
}
|
||||
splitEmail := strings.SplitN(email.Address, "@", -1)
|
||||
domain := strings.ToLower(splitEmail[len(splitEmail)-1])
|
||||
err = validDomain(domain)
|
||||
err = ValidDomain(domain)
|
||||
if err != nil {
|
||||
return berrors.InvalidEmailError(
|
||||
"contact email %q has invalid domain : %s",
|
||||
|
@ -363,7 +363,7 @@ func (pa *AuthorityImpl) willingToIssue(id identifier.ACMEIdentifier) error {
|
|||
}
|
||||
domain := id.Value
|
||||
|
||||
err := validDomain(domain)
|
||||
err := ValidDomain(domain)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -0,0 +1,190 @@
|
|||
# Configuring and Storing Key-Value Rate Limits
|
||||
|
||||
## Rate Limit Structure
|
||||
|
||||
All rate limits use a token-bucket model. The metaphor is that each limit is
|
||||
represented by a bucket which holds tokens. Each request removes some number of
|
||||
tokens from the bucket, or is denied if there aren't enough tokens to remove.
|
||||
Over time, new tokens are added to the bucket at a steady rate, until the bucket
|
||||
is full. The _burst_ parameter of a rate limit indicates the maximum capacity of
|
||||
a bucket: how many tokens can it hold before new ones stop being added.
|
||||
Therefore, this also indicates how many requests can be made in a single burst
|
||||
before a full bucket is completely emptied. The _count_ and _period_ parameters
|
||||
indicate the rate at which new tokens are added to a bucket: every period, count
|
||||
tokens will be added. Therefore, these also indicate the steady-state rate at
|
||||
which a client which has exhausted its quota can make requests: one token every
|
||||
(period / count) duration.
|
||||
|
||||
## Default Limit Settings
|
||||
|
||||
Each key directly corresponds to a `Name` enumeration as detailed in `//ratelimits/names.go`.
|
||||
The `Name` enum is used to identify the particular limit. The parameters of a
|
||||
default limit are the values that will be used for all buckets that do not have
|
||||
an explicit override (see below).
|
||||
|
||||
```yaml
|
||||
NewRegistrationsPerIPAddress:
|
||||
burst: 20
|
||||
count: 20
|
||||
period: 1s
|
||||
NewOrdersPerAccount:
|
||||
burst: 300
|
||||
count: 300
|
||||
period: 180m
|
||||
```
|
||||
|
||||
## Override Limit Settings
|
||||
|
||||
Each override key represents a specific bucket, consisting of two elements:
|
||||
_name_ and _id_. The name here refers to the Name of the particular limit, while
|
||||
the id is a client identifier. The format of the id is dependent on the limit.
|
||||
For example, the id for 'NewRegistrationsPerIPAddress' is a subscriber IP
|
||||
address, while the id for 'NewOrdersPerAccount' is the subscriber's registration
|
||||
ID.
|
||||
|
||||
```yaml
|
||||
NewRegistrationsPerIPAddress:10.0.0.2:
|
||||
burst: 20
|
||||
count: 40
|
||||
period: 1s
|
||||
NewOrdersPerAccount:12345678:
|
||||
burst: 300
|
||||
count: 600
|
||||
period: 180m
|
||||
```
|
||||
|
||||
The above example overrides the default limits for specific subscribers. In both
|
||||
cases the count of requests per period are doubled, but the burst capacity is
|
||||
explicitly configured to match the default rate limit.
|
||||
|
||||
### Id Formats in Limit Override Settings
|
||||
|
||||
Id formats vary based on the `Name` enumeration. Below are examples for each
|
||||
format:
|
||||
|
||||
#### ipAddress
|
||||
|
||||
A valid IPv4 or IPv6 address.
|
||||
|
||||
Examples:
|
||||
- `NewRegistrationsPerIPAddress:10.0.0.1`
|
||||
- `NewRegistrationsPerIPAddress:2001:0db8:0000:0000:0000:ff00:0042:8329`
|
||||
|
||||
#### ipv6RangeCIDR
|
||||
|
||||
A valid IPv6 range in CIDR notation with a /48 mask. A /48 range is typically
|
||||
assigned to a single subscriber.
|
||||
|
||||
Example: `NewRegistrationsPerIPv6Range:2001:0db8:0000::/48`
|
||||
|
||||
#### regId
|
||||
|
||||
The registration ID of the account.
|
||||
|
||||
Example: `NewOrdersPerAccount:12345678`
|
||||
|
||||
#### regId:domain
|
||||
|
||||
A combination of registration ID and domain, formatted 'regId:domain'.
|
||||
|
||||
Example: `CertificatesPerDomainPerAccount:12345678:example.com`
|
||||
|
||||
#### regId:fqdnSet
|
||||
|
||||
A combination of registration ID and a comma-separated list of domain names,
|
||||
formatted 'regId:fqdnSet'.
|
||||
|
||||
Example: `CertificatesPerFQDNSetPerAccount:12345678:example.com,example.org`
|
||||
|
||||
## Bucket Key Definitions
|
||||
|
||||
A bucket key is used to lookup the bucket for a given limit and
|
||||
subscriber. Bucket keys are formatted similarly to the overrides but with a
|
||||
slight difference: the limit Names do not carry the string form of each limit.
|
||||
Instead, they apply the `Name` enum equivalent for every limit.
|
||||
|
||||
So, instead of:
|
||||
|
||||
```
|
||||
NewOrdersPerAccount:12345678
|
||||
```
|
||||
|
||||
The corresponding bucket key for regId 12345678 would look like this:
|
||||
|
||||
```
|
||||
6:12345678
|
||||
```
|
||||
|
||||
When loaded from a file, the keys for the default/override limits undergo the
|
||||
same interning process as the aforementioned subscriber bucket keys. This
|
||||
eliminates the need for redundant conversions when fetching each
|
||||
default/override limit.
|
||||
|
||||
## How Limits are Applied
|
||||
|
||||
Although rate limit buckets are configured in terms of tokens, we do not
|
||||
actually keep track of the number of tokens in each bucket. Instead, we track
|
||||
the Theoretical Arrival Time (TAT) at which the bucket will be full again. If
|
||||
the TAT is in the past, the bucket is full. If the TAT is in the future, some
|
||||
number of tokens have been spent and the bucket is slowly refilling. If the TAT
|
||||
is far enough in the future (specifically, more than `burst * (period / count)`)
|
||||
in the future), then the bucket is completely empty and requests will be denied.
|
||||
|
||||
Additional terminology:
|
||||
|
||||
- **burst offset** is the duration of time it takes for a bucket to go from
|
||||
empty to full (`burst * (period / count)`).
|
||||
- **emission interval** is the interval at which tokens are added to a bucket
|
||||
(`period / count`). This is also the steady-state rate at which requests can
|
||||
be made without being denied even once the burst has been exhausted.
|
||||
- **cost** is the number of tokens removed from a bucket for a single request.
|
||||
- **cost increment** is the duration of time the TAT is advanced to account
|
||||
for the cost of the request (`cost * emission interval`).
|
||||
|
||||
For the purposes of this example, subscribers originating from a specific IPv4
|
||||
address are allowed 20 requests to the newFoo endpoint per second, with a
|
||||
maximum burst of 20 requests at any point-in-time, or:
|
||||
|
||||
```yaml
|
||||
NewFoosPerIPAddress:172.23.45.22:
|
||||
burst: 20
|
||||
count: 20
|
||||
period: 1s
|
||||
```
|
||||
|
||||
A subscriber calls the newFoo endpoint for the first time with an IP address of
|
||||
172.23.45.22. Here's what happens:
|
||||
|
||||
1. The subscriber's IP address is used to generate a bucket key in the form of
|
||||
'NewFoosPerIPAddress:172.23.45.22'.
|
||||
|
||||
2. The request is approved and the 'NewFoosPerIPAddress:172.23.45.22' bucket is
|
||||
initialized with 19 tokens, as 1 token has been removed to account for the
|
||||
cost of the current request. To accomplish this, the initial TAT is set to
|
||||
the current time plus the _cost increment_ (which is 1/20th of a second if we
|
||||
are limiting to 20 requests per second).
|
||||
|
||||
3. Bucket 'NewFoosPerIPAddress:172.23.45.22':
|
||||
- will reset to full in 50ms (1/20th of a second),
|
||||
- will allow another newFoo request immediately,
|
||||
- will allow between 1 and 19 more requests in the next 50ms,
|
||||
- will reject the 20th request made in the next 50ms,
|
||||
- and will allow 1 request every 50ms, indefinitely.
|
||||
|
||||
The subscriber makes another request 5ms later:
|
||||
|
||||
4. The TAT at bucket key 'NewFoosPerIPAddress:172.23.45.22' is compared against
|
||||
the current time and the _burst offset_. The current time is greater than the
|
||||
TAT minus the cost increment. Therefore, the request is approved.
|
||||
|
||||
5. The TAT at bucket key 'NewFoosPerIPAddress:172.23.45.22' is advanced by the
|
||||
cost increment to account for the cost of the request.
|
||||
|
||||
The subscriber makes a total of 18 requests over the next 44ms:
|
||||
|
||||
6. The current time is less than the TAT at bucket key
|
||||
'NewFoosPerIPAddress:172.23.45.22' minus the burst offset, thus the request
|
||||
is rejected.
|
||||
|
||||
This mechanism allows for bursts of traffic but also ensures that the average
|
||||
rate of requests stays within the prescribed limits over time.
|
|
@ -0,0 +1,110 @@
|
|||
package ratelimits
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/jmhodges/clock"
|
||||
)
|
||||
|
||||
// maybeSpend uses the GCRA algorithm to decide whether to allow a request. It
|
||||
// returns a Decision struct with the result of the decision and the updated
|
||||
// TAT. The cost must be 0 or greater and <= the burst capacity of the limit.
|
||||
func maybeSpend(clk clock.Clock, rl limit, tat time.Time, cost int64) *Decision {
|
||||
if cost < 0 || cost > rl.Burst {
|
||||
// The condition above is the union of the conditions checked in Check
|
||||
// and Spend methods of Limiter. If this panic is reached, it means that
|
||||
// the caller has introduced a bug.
|
||||
panic("invalid cost for maybeSpend")
|
||||
}
|
||||
nowUnix := clk.Now().UnixNano()
|
||||
tatUnix := tat.UnixNano()
|
||||
|
||||
// If the TAT is in the future, use it as the starting point for the
|
||||
// calculation. Otherwise, use the current time. This is to prevent the
|
||||
// bucket from being filled with capacity from the past.
|
||||
if nowUnix > tatUnix {
|
||||
tatUnix = nowUnix
|
||||
}
|
||||
|
||||
// Compute the cost increment.
|
||||
costIncrement := rl.emissionInterval * cost
|
||||
|
||||
// Deduct the cost to find the new TAT and residual capacity.
|
||||
newTAT := tatUnix + costIncrement
|
||||
difference := nowUnix - (newTAT - rl.burstOffset)
|
||||
|
||||
if difference < 0 {
|
||||
// Too little capacity to satisfy the cost, deny the request.
|
||||
residual := (nowUnix - (tatUnix - rl.burstOffset)) / rl.emissionInterval
|
||||
return &Decision{
|
||||
Allowed: false,
|
||||
Remaining: residual,
|
||||
RetryIn: -time.Duration(difference),
|
||||
ResetIn: time.Duration(tatUnix - nowUnix),
|
||||
newTAT: time.Unix(0, tatUnix).UTC(),
|
||||
}
|
||||
}
|
||||
|
||||
// There is enough capacity to satisfy the cost, allow the request.
|
||||
var retryIn time.Duration
|
||||
residual := difference / rl.emissionInterval
|
||||
if difference < costIncrement {
|
||||
retryIn = time.Duration(costIncrement - difference)
|
||||
}
|
||||
return &Decision{
|
||||
Allowed: true,
|
||||
Remaining: residual,
|
||||
RetryIn: retryIn,
|
||||
ResetIn: time.Duration(newTAT - nowUnix),
|
||||
newTAT: time.Unix(0, newTAT).UTC(),
|
||||
}
|
||||
}
|
||||
|
||||
// maybeRefund uses the Generic Cell Rate Algorithm (GCRA) to attempt to refund
|
||||
// the cost of a request which was previously spent. The refund cost must be 0
|
||||
// or greater. A cost will only be refunded up to the burst capacity of the
|
||||
// limit. A partial refund is still considered successful.
|
||||
func maybeRefund(clk clock.Clock, rl limit, tat time.Time, cost int64) *Decision {
|
||||
if cost <= 0 || cost > rl.Burst {
|
||||
// The condition above is checked in the Refund method of Limiter. If
|
||||
// this panic is reached, it means that the caller has introduced a bug.
|
||||
panic("invalid cost for maybeRefund")
|
||||
}
|
||||
nowUnix := clk.Now().UnixNano()
|
||||
tatUnix := tat.UnixNano()
|
||||
|
||||
// The TAT must be in the future to refund capacity.
|
||||
if nowUnix > tatUnix {
|
||||
// The TAT is in the past, therefore the bucket is full.
|
||||
return &Decision{
|
||||
Allowed: false,
|
||||
Remaining: rl.Burst,
|
||||
RetryIn: time.Duration(0),
|
||||
ResetIn: time.Duration(0),
|
||||
newTAT: tat,
|
||||
}
|
||||
}
|
||||
|
||||
// Compute the refund increment.
|
||||
refundIncrement := rl.emissionInterval * cost
|
||||
|
||||
// Subtract the refund increment from the TAT to find the new TAT.
|
||||
newTAT := tatUnix - refundIncrement
|
||||
|
||||
// Ensure the new TAT is not earlier than now.
|
||||
if newTAT < nowUnix {
|
||||
newTAT = nowUnix
|
||||
}
|
||||
|
||||
// Calculate the new capacity.
|
||||
difference := nowUnix - (newTAT - rl.burstOffset)
|
||||
residual := difference / rl.emissionInterval
|
||||
|
||||
return &Decision{
|
||||
Allowed: (newTAT != tatUnix),
|
||||
Remaining: residual,
|
||||
RetryIn: time.Duration(0),
|
||||
ResetIn: time.Duration(newTAT - nowUnix),
|
||||
newTAT: time.Unix(0, newTAT).UTC(),
|
||||
}
|
||||
}
|
|
@ -0,0 +1,221 @@
|
|||
package ratelimits
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/jmhodges/clock"
|
||||
"github.com/letsencrypt/boulder/config"
|
||||
"github.com/letsencrypt/boulder/test"
|
||||
)
|
||||
|
||||
func Test_decide(t *testing.T) {
|
||||
clk := clock.NewFake()
|
||||
limit := precomputeLimit(
|
||||
limit{Burst: 10, Count: 1, Period: config.Duration{Duration: time.Second}},
|
||||
)
|
||||
|
||||
// Begin by using 1 of our 10 requests.
|
||||
d := maybeSpend(clk, limit, clk.Now(), 1)
|
||||
test.Assert(t, d.Allowed, "should be allowed")
|
||||
test.AssertEquals(t, d.Remaining, int64(9))
|
||||
test.AssertEquals(t, d.RetryIn, time.Duration(0))
|
||||
test.AssertEquals(t, d.ResetIn, time.Second)
|
||||
|
||||
// Immediately use another 9 of our remaining requests.
|
||||
d = maybeSpend(clk, limit, d.newTAT, 9)
|
||||
test.Assert(t, d.Allowed, "should be allowed")
|
||||
test.AssertEquals(t, d.Remaining, int64(0))
|
||||
// We should have to wait 1 second before we can use another request but we
|
||||
// used 9 so we should have to wait 9 seconds to make an identical request.
|
||||
test.AssertEquals(t, d.RetryIn, time.Second*9)
|
||||
test.AssertEquals(t, d.ResetIn, time.Second*10)
|
||||
|
||||
// Our new TAT should be 10 seconds (limit.Burst) in the future.
|
||||
test.AssertEquals(t, d.newTAT, clk.Now().Add(time.Second*10))
|
||||
|
||||
// Let's try using just 1 more request without waiting.
|
||||
d = maybeSpend(clk, limit, d.newTAT, 1)
|
||||
test.Assert(t, !d.Allowed, "should not be allowed")
|
||||
test.AssertEquals(t, d.Remaining, int64(0))
|
||||
test.AssertEquals(t, d.RetryIn, time.Second)
|
||||
test.AssertEquals(t, d.ResetIn, time.Second*10)
|
||||
|
||||
// Let's try being exactly as patient as we're told to be.
|
||||
clk.Add(d.RetryIn)
|
||||
d = maybeSpend(clk, limit, d.newTAT, 0)
|
||||
test.AssertEquals(t, d.Remaining, int64(1))
|
||||
|
||||
// We are 1 second in the future, we should have 1 new request.
|
||||
d = maybeSpend(clk, limit, d.newTAT, 1)
|
||||
test.Assert(t, d.Allowed, "should be allowed")
|
||||
test.AssertEquals(t, d.Remaining, int64(0))
|
||||
test.AssertEquals(t, d.RetryIn, time.Second)
|
||||
test.AssertEquals(t, d.ResetIn, time.Second*10)
|
||||
|
||||
// Let's try waiting (10 seconds) for our whole bucket to refill.
|
||||
clk.Add(d.ResetIn)
|
||||
|
||||
// We should have 10 new requests. If we use 1 we should have 9 remaining.
|
||||
d = maybeSpend(clk, limit, d.newTAT, 1)
|
||||
test.Assert(t, d.Allowed, "should be allowed")
|
||||
test.AssertEquals(t, d.Remaining, int64(9))
|
||||
test.AssertEquals(t, d.RetryIn, time.Duration(0))
|
||||
test.AssertEquals(t, d.ResetIn, time.Second)
|
||||
|
||||
// Wait just shy of how long we're told to wait for refilling.
|
||||
clk.Add(d.ResetIn - time.Millisecond)
|
||||
|
||||
// We should still have 9 remaining because we're still 1ms shy of the
|
||||
// refill time.
|
||||
d = maybeSpend(clk, limit, d.newTAT, 0)
|
||||
test.Assert(t, d.Allowed, "should be allowed")
|
||||
test.AssertEquals(t, d.Remaining, int64(9))
|
||||
test.AssertEquals(t, d.RetryIn, time.Duration(0))
|
||||
test.AssertEquals(t, d.ResetIn, time.Millisecond)
|
||||
|
||||
// Spending 0 simply informed us that we still have 9 remaining, let's see
|
||||
// what we have after waiting 20 hours.
|
||||
clk.Add(20 * time.Hour)
|
||||
|
||||
// C'mon, big money, no whammies, no whammies, STOP!
|
||||
d = maybeSpend(clk, limit, d.newTAT, 0)
|
||||
test.Assert(t, d.Allowed, "should be allowed")
|
||||
test.AssertEquals(t, d.Remaining, int64(10))
|
||||
test.AssertEquals(t, d.RetryIn, time.Duration(0))
|
||||
test.AssertEquals(t, d.ResetIn, time.Duration(0))
|
||||
|
||||
// Turns out that the most we can accrue is 10 (limit.Burst). Let's empty
|
||||
// this bucket out so we can try something else.
|
||||
d = maybeSpend(clk, limit, d.newTAT, 10)
|
||||
test.Assert(t, d.Allowed, "should be allowed")
|
||||
test.AssertEquals(t, d.Remaining, int64(0))
|
||||
// We should have to wait 1 second before we can use another request but we
|
||||
// used 10 so we should have to wait 10 seconds to make an identical
|
||||
// request.
|
||||
test.AssertEquals(t, d.RetryIn, time.Second*10)
|
||||
test.AssertEquals(t, d.ResetIn, time.Second*10)
|
||||
|
||||
// If you spend 0 while you have 0 you should get 0.
|
||||
d = maybeSpend(clk, limit, d.newTAT, 0)
|
||||
test.Assert(t, d.Allowed, "should be allowed")
|
||||
test.AssertEquals(t, d.Remaining, int64(0))
|
||||
test.AssertEquals(t, d.RetryIn, time.Duration(0))
|
||||
test.AssertEquals(t, d.ResetIn, time.Second*10)
|
||||
|
||||
// We don't play by the rules, we spend 1 when we have 0.
|
||||
d = maybeSpend(clk, limit, d.newTAT, 1)
|
||||
test.Assert(t, !d.Allowed, "should not be allowed")
|
||||
test.AssertEquals(t, d.Remaining, int64(0))
|
||||
test.AssertEquals(t, d.RetryIn, time.Second)
|
||||
test.AssertEquals(t, d.ResetIn, time.Second*10)
|
||||
|
||||
// Okay, maybe we should play by the rules if we want to get anywhere.
|
||||
clk.Add(d.RetryIn)
|
||||
|
||||
// Our patience pays off, we should have 1 new request. Let's use it.
|
||||
d = maybeSpend(clk, limit, d.newTAT, 1)
|
||||
test.Assert(t, d.Allowed, "should be allowed")
|
||||
test.AssertEquals(t, d.Remaining, int64(0))
|
||||
test.AssertEquals(t, d.RetryIn, time.Second)
|
||||
test.AssertEquals(t, d.ResetIn, time.Second*10)
|
||||
|
||||
// Refill from empty to 5.
|
||||
clk.Add(d.ResetIn / 2)
|
||||
|
||||
// Attempt to spend 7 when we only have 5. We should be denied but the
|
||||
// decision should reflect a retry of 2 seconds, the time it would take to
|
||||
// refill from 5 to 7.
|
||||
d = maybeSpend(clk, limit, d.newTAT, 7)
|
||||
test.Assert(t, !d.Allowed, "should not be allowed")
|
||||
test.AssertEquals(t, d.Remaining, int64(5))
|
||||
test.AssertEquals(t, d.RetryIn, time.Second*2)
|
||||
test.AssertEquals(t, d.ResetIn, time.Second*5)
|
||||
}
|
||||
|
||||
func Test_maybeRefund(t *testing.T) {
|
||||
clk := clock.NewFake()
|
||||
limit := precomputeLimit(
|
||||
limit{Burst: 10, Count: 1, Period: config.Duration{Duration: time.Second}},
|
||||
)
|
||||
|
||||
// Begin by using 1 of our 10 requests.
|
||||
d := maybeSpend(clk, limit, clk.Now(), 1)
|
||||
test.Assert(t, d.Allowed, "should be allowed")
|
||||
test.AssertEquals(t, d.Remaining, int64(9))
|
||||
test.AssertEquals(t, d.RetryIn, time.Duration(0))
|
||||
test.AssertEquals(t, d.ResetIn, time.Second)
|
||||
|
||||
// Refund back to 10.
|
||||
d = maybeRefund(clk, limit, d.newTAT, 1)
|
||||
test.AssertEquals(t, d.Remaining, int64(10))
|
||||
test.AssertEquals(t, d.RetryIn, time.Duration(0))
|
||||
test.AssertEquals(t, d.ResetIn, time.Duration(0))
|
||||
|
||||
// Spend 1 more of our 10 requests.
|
||||
d = maybeSpend(clk, limit, d.newTAT, 1)
|
||||
test.Assert(t, d.Allowed, "should be allowed")
|
||||
test.AssertEquals(t, d.Remaining, int64(9))
|
||||
test.AssertEquals(t, d.RetryIn, time.Duration(0))
|
||||
test.AssertEquals(t, d.ResetIn, time.Second)
|
||||
|
||||
// Wait for our bucket to refill.
|
||||
clk.Add(d.ResetIn)
|
||||
|
||||
// Attempt to refund from 10 to 11.
|
||||
d = maybeRefund(clk, limit, d.newTAT, 1)
|
||||
test.Assert(t, !d.Allowed, "should not be allowed")
|
||||
test.AssertEquals(t, d.Remaining, int64(10))
|
||||
test.AssertEquals(t, d.RetryIn, time.Duration(0))
|
||||
test.AssertEquals(t, d.ResetIn, time.Duration(0))
|
||||
|
||||
// Spend 10 all 10 of our requests.
|
||||
d = maybeSpend(clk, limit, d.newTAT, 10)
|
||||
test.Assert(t, d.Allowed, "should be allowed")
|
||||
test.AssertEquals(t, d.Remaining, int64(0))
|
||||
// We should have to wait 1 second before we can use another request but we
|
||||
// used 10 so we should have to wait 10 seconds to make an identical
|
||||
// request.
|
||||
test.AssertEquals(t, d.RetryIn, time.Second*10)
|
||||
test.AssertEquals(t, d.ResetIn, time.Second*10)
|
||||
|
||||
// Attempt a refund of 10.
|
||||
d = maybeRefund(clk, limit, d.newTAT, 10)
|
||||
test.AssertEquals(t, d.Remaining, int64(10))
|
||||
test.AssertEquals(t, d.RetryIn, time.Duration(0))
|
||||
test.AssertEquals(t, d.ResetIn, time.Duration(0))
|
||||
|
||||
// Wait 11 seconds to catching up to TAT.
|
||||
clk.Add(11 * time.Second)
|
||||
|
||||
// Attempt to refund to 11, then ensure it's still 10.
|
||||
d = maybeRefund(clk, limit, d.newTAT, 1)
|
||||
test.Assert(t, !d.Allowed, "should be allowed")
|
||||
test.AssertEquals(t, d.Remaining, int64(10))
|
||||
test.AssertEquals(t, d.RetryIn, time.Duration(0))
|
||||
test.AssertEquals(t, d.ResetIn, time.Duration(0))
|
||||
|
||||
// Spend 5 of our 10 requests, then refund 1.
|
||||
d = maybeSpend(clk, limit, d.newTAT, 5)
|
||||
d = maybeRefund(clk, limit, d.newTAT, 1)
|
||||
test.Assert(t, d.Allowed, "should be allowed")
|
||||
test.AssertEquals(t, d.Remaining, int64(6))
|
||||
test.AssertEquals(t, d.RetryIn, time.Duration(0))
|
||||
|
||||
// Wait, a 2.5 seconds to refill to 8.5 requests.
|
||||
clk.Add(time.Millisecond * 2500)
|
||||
|
||||
// Ensure we have 8.5 requests.
|
||||
d = maybeSpend(clk, limit, d.newTAT, 0)
|
||||
test.Assert(t, d.Allowed, "should be allowed")
|
||||
test.AssertEquals(t, d.Remaining, int64(8))
|
||||
test.AssertEquals(t, d.RetryIn, time.Duration(0))
|
||||
// Check that ResetIn represents the fractional earned request.
|
||||
test.AssertEquals(t, d.ResetIn, time.Millisecond*1500)
|
||||
|
||||
// Refund 2 requests, we should only have 10, not 10.5.
|
||||
d = maybeRefund(clk, limit, d.newTAT, 2)
|
||||
test.AssertEquals(t, d.Remaining, int64(10))
|
||||
test.AssertEquals(t, d.RetryIn, time.Duration(0))
|
||||
test.AssertEquals(t, d.ResetIn, time.Duration(0))
|
||||
}
|
|
@ -0,0 +1,160 @@
|
|||
package ratelimits
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/letsencrypt/boulder/config"
|
||||
"github.com/letsencrypt/boulder/core"
|
||||
"github.com/letsencrypt/boulder/strictyaml"
|
||||
)
|
||||
|
||||
type limit struct {
|
||||
// Burst specifies maximum concurrent allowed requests at any given time. It
|
||||
// must be greater than zero.
|
||||
Burst int64
|
||||
|
||||
// Count is the number of requests allowed per period. It must be greater
|
||||
// than zero.
|
||||
Count int64
|
||||
|
||||
// Period is the duration of time in which the count (of requests) is
|
||||
// allowed. It must be greater than zero.
|
||||
Period config.Duration
|
||||
|
||||
// emissionInterval is the interval, in nanoseconds, at which tokens are
|
||||
// added to a bucket (period / count). This is also the steady-state rate at
|
||||
// which requests can be made without being denied even once the burst has
|
||||
// been exhausted. This is precomputed to avoid doing the same calculation
|
||||
// on every request.
|
||||
emissionInterval int64
|
||||
|
||||
// burstOffset is the duration of time, in nanoseconds, it takes for a
|
||||
// bucket to go from empty to full (burst * (period / count)). This is
|
||||
// precomputed to avoid doing the same calculation on every request.
|
||||
burstOffset int64
|
||||
}
|
||||
|
||||
func precomputeLimit(l limit) limit {
|
||||
l.emissionInterval = l.Period.Nanoseconds() / l.Count
|
||||
l.burstOffset = l.emissionInterval * l.Burst
|
||||
return l
|
||||
}
|
||||
|
||||
func validateLimit(l limit) error {
|
||||
if l.Burst <= 0 {
|
||||
return fmt.Errorf("invalid burst '%d', must be > 0", l.Burst)
|
||||
}
|
||||
if l.Count <= 0 {
|
||||
return fmt.Errorf("invalid count '%d', must be > 0", l.Count)
|
||||
}
|
||||
if l.Period.Duration <= 0 {
|
||||
return fmt.Errorf("invalid period '%s', must be > 0", l.Period)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type limits map[string]limit
|
||||
|
||||
// loadLimits marshals the YAML file at path into a map of limis.
|
||||
func loadLimits(path string) (limits, error) {
|
||||
lm := make(limits)
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = strictyaml.Unmarshal(data, &lm)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return lm, nil
|
||||
}
|
||||
|
||||
// parseOverrideNameId is broken out for ease of testing.
|
||||
func parseOverrideNameId(key string) (Name, string, error) {
|
||||
if !strings.Contains(key, ":") {
|
||||
// Avoids a potential panic in strings.SplitN below.
|
||||
return Unknown, "", fmt.Errorf("invalid override %q, must be formatted 'name:id'", key)
|
||||
}
|
||||
nameAndId := strings.SplitN(key, ":", 2)
|
||||
nameStr := nameAndId[0]
|
||||
if nameStr == "" {
|
||||
return Unknown, "", fmt.Errorf("empty name in override %q, must be formatted 'name:id'", key)
|
||||
}
|
||||
|
||||
name, ok := stringToName[nameStr]
|
||||
if !ok {
|
||||
return Unknown, "", fmt.Errorf("unrecognized name %q in override limit %q, must be one of %v", nameStr, key, limitNames)
|
||||
}
|
||||
id := nameAndId[1]
|
||||
if id == "" {
|
||||
return Unknown, "", fmt.Errorf("empty id in override %q, must be formatted 'name:id'", key)
|
||||
}
|
||||
return name, id, nil
|
||||
}
|
||||
|
||||
// loadAndParseOverrideLimits loads override limits from YAML, validates them,
|
||||
// and parses them into a map of limits keyed by 'Name:id'.
|
||||
func loadAndParseOverrideLimits(path string) (limits, error) {
|
||||
fromFile, err := loadLimits(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
parsed := make(limits, len(fromFile))
|
||||
|
||||
for k, v := range fromFile {
|
||||
err = validateLimit(v)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("validating override limit %q: %w", k, err)
|
||||
}
|
||||
name, id, err := parseOverrideNameId(k)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parsing override limit %q: %w", k, err)
|
||||
}
|
||||
err = validateIdForName(name, id)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf(
|
||||
"validating name %s and id %q for override limit %q: %w", nameToString[name], id, k, err)
|
||||
}
|
||||
if name == CertificatesPerFQDNSetPerAccount {
|
||||
// FQDNSet hashes are not a nice thing to ask for in a config file,
|
||||
// so we allow the user to specify a comma-separated list of FQDNs
|
||||
// and compute the hash here.
|
||||
regIdDomains := strings.SplitN(id, ":", 2)
|
||||
if len(regIdDomains) != 2 {
|
||||
// Should never happen, the Id format was validated above.
|
||||
return nil, fmt.Errorf("invalid override limit %q, must be formatted 'name:id'", k)
|
||||
}
|
||||
regId := regIdDomains[0]
|
||||
domains := strings.Split(regIdDomains[1], ",")
|
||||
fqdnSet := core.HashNames(domains)
|
||||
id = fmt.Sprintf("%s:%s", regId, fqdnSet)
|
||||
}
|
||||
parsed[bucketKey(name, id)] = precomputeLimit(v)
|
||||
}
|
||||
return parsed, nil
|
||||
}
|
||||
|
||||
// loadAndParseDefaultLimits loads default limits from YAML, validates them, and
|
||||
// parses them into a map of limits keyed by 'Name'.
|
||||
func loadAndParseDefaultLimits(path string) (limits, error) {
|
||||
fromFile, err := loadLimits(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
parsed := make(limits, len(fromFile))
|
||||
|
||||
for k, v := range fromFile {
|
||||
err := validateLimit(v)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parsing default limit %q: %w", k, err)
|
||||
}
|
||||
name, ok := stringToName[k]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unrecognized name %q in default limit, must be one of %v", k, limitNames)
|
||||
}
|
||||
parsed[nameToEnumString(name)] = precomputeLimit(v)
|
||||
}
|
||||
return parsed, nil
|
||||
}
|
|
@ -0,0 +1,333 @@
|
|||
package ratelimits
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/letsencrypt/boulder/config"
|
||||
"github.com/letsencrypt/boulder/core"
|
||||
"github.com/letsencrypt/boulder/test"
|
||||
)
|
||||
|
||||
func Test_parseOverrideNameId(t *testing.T) {
|
||||
newRegistrationsPerIPAddressStr := nameToString[NewRegistrationsPerIPAddress]
|
||||
newRegistrationsPerIPv6RangeStr := nameToString[NewRegistrationsPerIPv6Range]
|
||||
|
||||
// 'enum:ipv4'
|
||||
// Valid IPv4 address.
|
||||
name, id, err := parseOverrideNameId(newRegistrationsPerIPAddressStr + ":10.0.0.1")
|
||||
test.AssertNotError(t, err, "should not error")
|
||||
test.AssertEquals(t, name, NewRegistrationsPerIPAddress)
|
||||
test.AssertEquals(t, id, "10.0.0.1")
|
||||
|
||||
// 'enum:ipv6range'
|
||||
// Valid IPv6 address range.
|
||||
name, id, err = parseOverrideNameId(newRegistrationsPerIPv6RangeStr + ":2001:0db8:0000::/48")
|
||||
test.AssertNotError(t, err, "should not error")
|
||||
test.AssertEquals(t, name, NewRegistrationsPerIPv6Range)
|
||||
test.AssertEquals(t, id, "2001:0db8:0000::/48")
|
||||
|
||||
// Missing colon (this should never happen but we should avoid panicking).
|
||||
_, _, err = parseOverrideNameId(newRegistrationsPerIPAddressStr + "10.0.0.1")
|
||||
test.AssertError(t, err, "missing colon")
|
||||
|
||||
// Empty string.
|
||||
_, _, err = parseOverrideNameId("")
|
||||
test.AssertError(t, err, "empty string")
|
||||
|
||||
// Only a colon.
|
||||
_, _, err = parseOverrideNameId(newRegistrationsPerIPAddressStr + ":")
|
||||
test.AssertError(t, err, "only a colon")
|
||||
|
||||
// Invalid enum.
|
||||
_, _, err = parseOverrideNameId("lol:noexist")
|
||||
test.AssertError(t, err, "invalid enum")
|
||||
}
|
||||
|
||||
func Test_validateLimit(t *testing.T) {
|
||||
err := validateLimit(limit{Burst: 1, Count: 1, Period: config.Duration{Duration: time.Second}})
|
||||
test.AssertNotError(t, err, "valid limit")
|
||||
|
||||
// All of the following are invalid.
|
||||
for _, l := range []limit{
|
||||
{Burst: 0, Count: 1, Period: config.Duration{Duration: time.Second}},
|
||||
{Burst: 1, Count: 0, Period: config.Duration{Duration: time.Second}},
|
||||
{Burst: 1, Count: 1, Period: config.Duration{Duration: 0}},
|
||||
} {
|
||||
err = validateLimit(l)
|
||||
test.AssertError(t, err, "limit should be invalid")
|
||||
}
|
||||
}
|
||||
|
||||
func Test_validateIdForName(t *testing.T) {
|
||||
// 'enum:ipAddress'
|
||||
// Valid IPv4 address.
|
||||
err := validateIdForName(NewRegistrationsPerIPAddress, "10.0.0.1")
|
||||
test.AssertNotError(t, err, "valid ipv4 address")
|
||||
|
||||
// 'enum:ipAddress'
|
||||
// Valid IPv6 address.
|
||||
err = validateIdForName(NewRegistrationsPerIPAddress, "2001:0db8:85a3:0000:0000:8a2e:0370:7334")
|
||||
test.AssertNotError(t, err, "valid ipv6 address")
|
||||
|
||||
// 'enum:ipv6rangeCIDR'
|
||||
// Valid IPv6 address range.
|
||||
err = validateIdForName(NewRegistrationsPerIPv6Range, "2001:0db8:0000::/48")
|
||||
test.AssertNotError(t, err, "should not error")
|
||||
|
||||
// 'enum:regId'
|
||||
// Valid regId.
|
||||
err = validateIdForName(NewOrdersPerAccount, "1234567890")
|
||||
test.AssertNotError(t, err, "valid regId")
|
||||
|
||||
// 'enum:regId:domain'
|
||||
// Valid regId and domain.
|
||||
err = validateIdForName(CertificatesPerDomainPerAccount, "1234567890:example.com")
|
||||
test.AssertNotError(t, err, "valid regId and domain")
|
||||
|
||||
// 'enum:regId:fqdnSet'
|
||||
// Valid regId and FQDN set containing a single domain.
|
||||
err = validateIdForName(CertificatesPerFQDNSetPerAccount, "1234567890:example.com")
|
||||
test.AssertNotError(t, err, "valid regId and FQDN set containing a single domain")
|
||||
|
||||
// 'enum:regId:fqdnSet'
|
||||
// Valid regId and FQDN set containing multiple domains.
|
||||
err = validateIdForName(CertificatesPerFQDNSetPerAccount, "1234567890:example.com,example.org")
|
||||
test.AssertNotError(t, err, "valid regId and FQDN set containing multiple domains")
|
||||
|
||||
// Empty string.
|
||||
err = validateIdForName(NewRegistrationsPerIPAddress, "")
|
||||
test.AssertError(t, err, "Id is an empty string")
|
||||
|
||||
// One space.
|
||||
err = validateIdForName(NewRegistrationsPerIPAddress, " ")
|
||||
test.AssertError(t, err, "Id is a single space")
|
||||
|
||||
// Invalid IPv4 address.
|
||||
err = validateIdForName(NewRegistrationsPerIPAddress, "10.0.0.9000")
|
||||
test.AssertError(t, err, "invalid IPv4 address")
|
||||
|
||||
// Invalid IPv6 address.
|
||||
err = validateIdForName(NewRegistrationsPerIPAddress, "2001:0db8:85a3:0000:0000:8a2e:0370:7334:9000")
|
||||
test.AssertError(t, err, "invalid IPv6 address")
|
||||
|
||||
// Invalid IPv6 CIDR range.
|
||||
err = validateIdForName(NewRegistrationsPerIPv6Range, "2001:0db8:0000::/128")
|
||||
test.AssertError(t, err, "invalid IPv6 CIDR range")
|
||||
|
||||
// Invalid IPv6 CIDR.
|
||||
err = validateIdForName(NewRegistrationsPerIPv6Range, "2001:0db8:0000::/48/48")
|
||||
test.AssertError(t, err, "invalid IPv6 CIDR")
|
||||
|
||||
// IPv4 CIDR when we expect IPv6 CIDR range.
|
||||
err = validateIdForName(NewRegistrationsPerIPv6Range, "10.0.0.0/16")
|
||||
test.AssertError(t, err, "ipv4 cidr when we expect ipv6 cidr range")
|
||||
|
||||
// Invalid regId.
|
||||
err = validateIdForName(NewOrdersPerAccount, "lol")
|
||||
test.AssertError(t, err, "invalid regId")
|
||||
|
||||
// Invalid regId with good domain.
|
||||
err = validateIdForName(CertificatesPerDomainPerAccount, "lol:example.com")
|
||||
test.AssertError(t, err, "invalid regId with good domain")
|
||||
|
||||
// Valid regId with bad domain.
|
||||
err = validateIdForName(CertificatesPerDomainPerAccount, "1234567890:lol")
|
||||
test.AssertError(t, err, "valid regId with bad domain")
|
||||
|
||||
// Empty regId with good domain.
|
||||
err = validateIdForName(CertificatesPerDomainPerAccount, ":lol")
|
||||
test.AssertError(t, err, "valid regId with bad domain")
|
||||
|
||||
// Valid regId with empty domain.
|
||||
err = validateIdForName(CertificatesPerDomainPerAccount, "1234567890:")
|
||||
test.AssertError(t, err, "valid regId with empty domain")
|
||||
|
||||
// Empty regId with empty domain, no separator.
|
||||
err = validateIdForName(CertificatesPerDomainPerAccount, "")
|
||||
test.AssertError(t, err, "empty regId with empty domain, no separator")
|
||||
|
||||
// Instead of anything we would expect, we get lol.
|
||||
err = validateIdForName(CertificatesPerDomainPerAccount, "lol")
|
||||
test.AssertError(t, err, "instead of anything we would expect, just lol")
|
||||
|
||||
// Valid regId with good domain and a secret third separator.
|
||||
err = validateIdForName(CertificatesPerDomainPerAccount, "1234567890:example.com:lol")
|
||||
test.AssertError(t, err, "valid regId with good domain and a secret third separator")
|
||||
|
||||
// Valid regId with bad FQDN set.
|
||||
err = validateIdForName(CertificatesPerFQDNSetPerAccount, "1234567890:lol..99")
|
||||
test.AssertError(t, err, "valid regId with bad FQDN set")
|
||||
|
||||
// Bad regId with good FQDN set.
|
||||
err = validateIdForName(CertificatesPerFQDNSetPerAccount, "lol:example.com,example.org")
|
||||
test.AssertError(t, err, "bad regId with good FQDN set")
|
||||
|
||||
// Empty regId with good FQDN set.
|
||||
err = validateIdForName(CertificatesPerFQDNSetPerAccount, ":example.com,example.org")
|
||||
test.AssertError(t, err, "empty regId with good FQDN set")
|
||||
|
||||
// Good regId with empty FQDN set.
|
||||
err = validateIdForName(CertificatesPerFQDNSetPerAccount, "1234567890:")
|
||||
test.AssertError(t, err, "good regId with empty FQDN set")
|
||||
|
||||
// Empty regId with empty FQDN set, no separator.
|
||||
err = validateIdForName(CertificatesPerFQDNSetPerAccount, "")
|
||||
test.AssertError(t, err, "empty regId with empty FQDN set, no separator")
|
||||
|
||||
// Instead of anything we would expect, just lol.
|
||||
err = validateIdForName(CertificatesPerFQDNSetPerAccount, "lol")
|
||||
test.AssertError(t, err, "instead of anything we would expect, just lol")
|
||||
|
||||
// Valid regId with good FQDN set and a secret third separator.
|
||||
err = validateIdForName(CertificatesPerFQDNSetPerAccount, "1234567890:example.com,example.org:lol")
|
||||
test.AssertError(t, err, "valid regId with good FQDN set and a secret third separator")
|
||||
}
|
||||
|
||||
func Test_loadAndParseOverrideLimits(t *testing.T) {
|
||||
newRegistrationsPerIPAddressEnumStr := nameToEnumString(NewRegistrationsPerIPAddress)
|
||||
newRegistrationsPerIPv6RangeEnumStr := nameToEnumString(NewRegistrationsPerIPv6Range)
|
||||
|
||||
// Load a single valid override limit with Id formatted as 'enum:RegId'.
|
||||
l, err := loadAndParseOverrideLimits("testdata/working_override.yml")
|
||||
test.AssertNotError(t, err, "valid single override limit")
|
||||
expectKey := newRegistrationsPerIPAddressEnumStr + ":" + "10.0.0.2"
|
||||
test.AssertEquals(t, l[expectKey].Burst, int64(40))
|
||||
test.AssertEquals(t, l[expectKey].Count, int64(40))
|
||||
test.AssertEquals(t, l[expectKey].Period.Duration, time.Second)
|
||||
|
||||
// Load single valid override limit with Id formatted as 'regId:domain'.
|
||||
l, err = loadAndParseOverrideLimits("testdata/working_override_regid_domain.yml")
|
||||
test.AssertNotError(t, err, "valid single override limit with Id of regId:domain")
|
||||
expectKey = nameToEnumString(CertificatesPerDomainPerAccount) + ":" + "12345678:example.com"
|
||||
test.AssertEquals(t, l[expectKey].Burst, int64(40))
|
||||
test.AssertEquals(t, l[expectKey].Count, int64(40))
|
||||
test.AssertEquals(t, l[expectKey].Period.Duration, time.Second)
|
||||
|
||||
// Load multiple valid override limits with 'enum:RegId' Ids.
|
||||
l, err = loadAndParseOverrideLimits("testdata/working_overrides.yml")
|
||||
expectKey1 := newRegistrationsPerIPAddressEnumStr + ":" + "10.0.0.2"
|
||||
test.AssertNotError(t, err, "multiple valid override limits")
|
||||
test.AssertEquals(t, l[expectKey1].Burst, int64(40))
|
||||
test.AssertEquals(t, l[expectKey1].Count, int64(40))
|
||||
test.AssertEquals(t, l[expectKey1].Period.Duration, time.Second)
|
||||
expectKey2 := newRegistrationsPerIPv6RangeEnumStr + ":" + "2001:0db8:0000::/48"
|
||||
test.AssertEquals(t, l[expectKey2].Burst, int64(50))
|
||||
test.AssertEquals(t, l[expectKey2].Count, int64(50))
|
||||
test.AssertEquals(t, l[expectKey2].Period.Duration, time.Second*2)
|
||||
|
||||
// Load multiple valid override limits with 'regID:fqdnSet' Ids as follows:
|
||||
// - CertificatesPerFQDNSetPerAccount:12345678:example.com
|
||||
// - CertificatesPerFQDNSetPerAccount:12345678:example.com,example.net
|
||||
// - CertificatesPerFQDNSetPerAccount:12345678:example.com,example.net,example.org
|
||||
firstEntryFQDNSetHash := string(core.HashNames([]string{"example.com"}))
|
||||
secondEntryFQDNSetHash := string(core.HashNames([]string{"example.com", "example.net"}))
|
||||
thirdEntryFQDNSetHash := string(core.HashNames([]string{"example.com", "example.net", "example.org"}))
|
||||
firstEntryKey := nameToEnumString(CertificatesPerFQDNSetPerAccount) + ":" + "12345678:" + firstEntryFQDNSetHash
|
||||
secondEntryKey := nameToEnumString(CertificatesPerFQDNSetPerAccount) + ":" + "12345678:" + secondEntryFQDNSetHash
|
||||
thirdEntryKey := nameToEnumString(CertificatesPerFQDNSetPerAccount) + ":" + "12345678:" + thirdEntryFQDNSetHash
|
||||
l, err = loadAndParseOverrideLimits("testdata/working_overrides_regid_fqdnset.yml")
|
||||
test.AssertNotError(t, err, "multiple valid override limits with Id of regId:fqdnSets")
|
||||
test.AssertEquals(t, l[firstEntryKey].Burst, int64(40))
|
||||
test.AssertEquals(t, l[firstEntryKey].Count, int64(40))
|
||||
test.AssertEquals(t, l[firstEntryKey].Period.Duration, time.Second)
|
||||
test.AssertEquals(t, l[secondEntryKey].Burst, int64(50))
|
||||
test.AssertEquals(t, l[secondEntryKey].Count, int64(50))
|
||||
test.AssertEquals(t, l[secondEntryKey].Period.Duration, time.Second*2)
|
||||
test.AssertEquals(t, l[thirdEntryKey].Burst, int64(60))
|
||||
test.AssertEquals(t, l[thirdEntryKey].Count, int64(60))
|
||||
test.AssertEquals(t, l[thirdEntryKey].Period.Duration, time.Second*3)
|
||||
|
||||
// Path is empty string.
|
||||
_, err = loadAndParseOverrideLimits("")
|
||||
test.AssertError(t, err, "path is empty string")
|
||||
test.Assert(t, os.IsNotExist(err), "path is empty string")
|
||||
|
||||
// Path to file which does not exist.
|
||||
_, err = loadAndParseOverrideLimits("testdata/file_does_not_exist.yml")
|
||||
test.AssertError(t, err, "a file that does not exist ")
|
||||
test.Assert(t, os.IsNotExist(err), "test file should not exist")
|
||||
|
||||
// Burst cannot be 0.
|
||||
_, err = loadAndParseOverrideLimits("testdata/busted_override_burst_0.yml")
|
||||
test.AssertError(t, err, "single override limit with burst=0")
|
||||
test.Assert(t, !os.IsNotExist(err), "test file should exist")
|
||||
|
||||
// Id cannot be empty.
|
||||
_, err = loadAndParseOverrideLimits("testdata/busted_override_empty_id.yml")
|
||||
test.AssertError(t, err, "single override limit with empty id")
|
||||
test.Assert(t, !os.IsNotExist(err), "test file should exist")
|
||||
|
||||
// Name cannot be empty.
|
||||
_, err = loadAndParseOverrideLimits("testdata/busted_override_empty_name.yml")
|
||||
test.AssertError(t, err, "single override limit with empty name")
|
||||
test.Assert(t, !os.IsNotExist(err), "test file should exist")
|
||||
|
||||
// Name must be a string representation of a valid Name enumeration.
|
||||
_, err = loadAndParseOverrideLimits("testdata/busted_override_invalid_name.yml")
|
||||
test.AssertError(t, err, "single override limit with invalid name")
|
||||
test.Assert(t, !os.IsNotExist(err), "test file should exist")
|
||||
|
||||
// Multiple entries, second entry has a bad name.
|
||||
_, err = loadAndParseOverrideLimits("testdata/busted_overrides_second_entry_bad_name.yml")
|
||||
test.AssertError(t, err, "multiple override limits, second entry is bad")
|
||||
test.Assert(t, !os.IsNotExist(err), "test file should exist")
|
||||
|
||||
// Multiple entries, third entry has id of "lol", instead of an IPv4 address.
|
||||
_, err = loadAndParseOverrideLimits("testdata/busted_overrides_third_entry_bad_id.yml")
|
||||
test.AssertError(t, err, "multiple override limits, third entry has bad Id value")
|
||||
test.Assert(t, !os.IsNotExist(err), "test file should exist")
|
||||
}
|
||||
|
||||
func Test_loadAndParseDefaultLimits(t *testing.T) {
|
||||
newRestistrationsPerIPv4AddressEnumStr := nameToEnumString(NewRegistrationsPerIPAddress)
|
||||
newRegistrationsPerIPv6RangeEnumStr := nameToEnumString(NewRegistrationsPerIPv6Range)
|
||||
|
||||
// Load a single valid default limit.
|
||||
l, err := loadAndParseDefaultLimits("testdata/working_default.yml")
|
||||
test.AssertNotError(t, err, "valid single default limit")
|
||||
test.AssertEquals(t, l[newRestistrationsPerIPv4AddressEnumStr].Burst, int64(20))
|
||||
test.AssertEquals(t, l[newRestistrationsPerIPv4AddressEnumStr].Count, int64(20))
|
||||
test.AssertEquals(t, l[newRestistrationsPerIPv4AddressEnumStr].Period.Duration, time.Second)
|
||||
|
||||
// Load multiple valid default limits.
|
||||
l, err = loadAndParseDefaultLimits("testdata/working_defaults.yml")
|
||||
test.AssertNotError(t, err, "multiple valid default limits")
|
||||
test.AssertEquals(t, l[newRestistrationsPerIPv4AddressEnumStr].Burst, int64(20))
|
||||
test.AssertEquals(t, l[newRestistrationsPerIPv4AddressEnumStr].Count, int64(20))
|
||||
test.AssertEquals(t, l[newRestistrationsPerIPv4AddressEnumStr].Period.Duration, time.Second)
|
||||
test.AssertEquals(t, l[newRegistrationsPerIPv6RangeEnumStr].Burst, int64(30))
|
||||
test.AssertEquals(t, l[newRegistrationsPerIPv6RangeEnumStr].Count, int64(30))
|
||||
test.AssertEquals(t, l[newRegistrationsPerIPv6RangeEnumStr].Period.Duration, time.Second*2)
|
||||
|
||||
// Path is empty string.
|
||||
_, err = loadAndParseDefaultLimits("")
|
||||
test.AssertError(t, err, "path is empty string")
|
||||
test.Assert(t, os.IsNotExist(err), "path is empty string")
|
||||
|
||||
// Path to file which does not exist.
|
||||
_, err = loadAndParseDefaultLimits("testdata/file_does_not_exist.yml")
|
||||
test.AssertError(t, err, "a file that does not exist")
|
||||
test.Assert(t, os.IsNotExist(err), "test file should not exist")
|
||||
|
||||
// Burst cannot be 0.
|
||||
_, err = loadAndParseDefaultLimits("testdata/busted_default_burst_0.yml")
|
||||
test.AssertError(t, err, "single default limit with burst=0")
|
||||
test.Assert(t, !os.IsNotExist(err), "test file should exist")
|
||||
|
||||
// Name cannot be empty.
|
||||
_, err = loadAndParseDefaultLimits("testdata/busted_default_empty_name.yml")
|
||||
test.AssertError(t, err, "single default limit with empty name")
|
||||
test.Assert(t, !os.IsNotExist(err), "test file should exist")
|
||||
|
||||
// Name must be a string representation of a valid Name enumeration.
|
||||
_, err = loadAndParseDefaultLimits("testdata/busted_default_invalid_name.yml")
|
||||
test.AssertError(t, err, "single default limit with invalid name")
|
||||
test.Assert(t, !os.IsNotExist(err), "test file should exist")
|
||||
|
||||
// Multiple entries, second entry has a bad name.
|
||||
_, err = loadAndParseDefaultLimits("testdata/busted_defaults_second_entry_bad_name.yml")
|
||||
test.AssertError(t, err, "multiple default limits, one is bad")
|
||||
test.Assert(t, !os.IsNotExist(err), "test file should exist")
|
||||
}
|
|
@ -0,0 +1,234 @@
|
|||
package ratelimits
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/jmhodges/clock"
|
||||
)
|
||||
|
||||
// ErrInvalidCost indicates that the cost specified was <= 0.
|
||||
var ErrInvalidCost = fmt.Errorf("invalid cost, must be > 0")
|
||||
|
||||
// ErrInvalidCostForCheck indicates that the check cost specified was < 0.
|
||||
var ErrInvalidCostForCheck = fmt.Errorf("invalid check cost, must be >= 0")
|
||||
|
||||
// ErrInvalidCostOverLimit indicates that the cost specified was > limit.Burst.
|
||||
var ErrInvalidCostOverLimit = fmt.Errorf("invalid cost, must be <= limit.Burst")
|
||||
|
||||
// ErrBucketAlreadyFull indicates that the bucket already has reached its
|
||||
// maximum capacity.
|
||||
var ErrBucketAlreadyFull = fmt.Errorf("bucket already full")
|
||||
|
||||
// Limiter provides a high-level interface for rate limiting requests by
|
||||
// utilizing a leaky bucket-style approach.
|
||||
type Limiter struct {
|
||||
// defaults stores default limits by 'name'.
|
||||
defaults limits
|
||||
|
||||
// overrides stores override limits by 'name:id'.
|
||||
overrides limits
|
||||
|
||||
// source is used to store buckets. It must be safe for concurrent use.
|
||||
source source
|
||||
clk clock.Clock
|
||||
}
|
||||
|
||||
// NewLimiter returns a new *Limiter. The provided source must be safe for
|
||||
// concurrent use. The defaults and overrides paths are expected to be paths to
|
||||
// YAML files that contain the default and override limits, respectively. The
|
||||
// overrides file is optional, all other arguments are required.
|
||||
func NewLimiter(clk clock.Clock, source source, defaults, overrides string) (*Limiter, error) {
|
||||
limiter := &Limiter{source: source, clk: clk}
|
||||
|
||||
var err error
|
||||
limiter.defaults, err = loadAndParseDefaultLimits(defaults)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if overrides == "" {
|
||||
// No overrides specified, initialize an empty map.
|
||||
limiter.overrides = make(limits)
|
||||
return limiter, nil
|
||||
}
|
||||
|
||||
limiter.overrides, err = loadAndParseOverrideLimits(overrides)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return limiter, nil
|
||||
}
|
||||
|
||||
type Decision struct {
|
||||
// Allowed is true if the bucket possessed enough capacity to allow the
|
||||
// request given the cost.
|
||||
Allowed bool
|
||||
|
||||
// Remaining is the number of requests the client is allowed to make before
|
||||
// they're rate limited.
|
||||
Remaining int64
|
||||
|
||||
// RetryIn is the duration the client MUST wait before they're allowed to
|
||||
// make a request.
|
||||
RetryIn time.Duration
|
||||
|
||||
// ResetIn is the duration the bucket will take to refill to its maximum
|
||||
// capacity, assuming no further requests are made.
|
||||
ResetIn time.Duration
|
||||
|
||||
// newTAT indicates the time at which the bucket will be full. It is the
|
||||
// theoretical arrival time (TAT) of next request. It must be no more than
|
||||
// (burst * (period / count)) in the future at any single point in time.
|
||||
newTAT time.Time
|
||||
}
|
||||
|
||||
// Check returns a *Decision that indicates whether there's enough capacity to
|
||||
// allow the request, given the cost, for the specified limit Name and client
|
||||
// id. However, it DOES NOT deduct the cost of the request from the bucket's
|
||||
// capacity. Hence, the returned *Decision represents the hypothetical state of
|
||||
// the bucket if the cost WERE to be deducted. The returned *Decision will
|
||||
// always include the number of remaining requests in the bucket, the required
|
||||
// wait time before the client can make another request, and the time until the
|
||||
// bucket refills to its maximum capacity (resets). If no bucket exists for the
|
||||
// given limit Name and client id, a new one will be created WITHOUT the
|
||||
// request's cost deducted from its initial capacity.
|
||||
func (l *Limiter) Check(name Name, id string, cost int64) (*Decision, error) {
|
||||
if cost < 0 {
|
||||
return nil, ErrInvalidCostForCheck
|
||||
}
|
||||
|
||||
limit, err := l.getLimit(name, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if cost > limit.Burst {
|
||||
return nil, ErrInvalidCostOverLimit
|
||||
}
|
||||
|
||||
tat, err := l.source.Get(bucketKey(name, id))
|
||||
if err != nil {
|
||||
if !errors.Is(err, ErrBucketNotFound) {
|
||||
return nil, err
|
||||
}
|
||||
// First request from this client. The cost is not deducted from the
|
||||
// initial capacity because this is only a check.
|
||||
d, err := l.initialize(limit, name, id, 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return maybeSpend(l.clk, limit, d.newTAT, cost), nil
|
||||
}
|
||||
return maybeSpend(l.clk, limit, tat, cost), nil
|
||||
}
|
||||
|
||||
// Spend returns a *Decision that indicates if enough capacity was available to
|
||||
// process the request, given the cost, for the specified limit Name and client
|
||||
// id. If capacity existed, the cost of the request HAS been deducted from the
|
||||
// bucket's capacity, otherwise no cost was deducted. The returned *Decision
|
||||
// will always include the number of remaining requests in the bucket, the
|
||||
// required wait time before the client can make another request, and the time
|
||||
// until the bucket refills to its maximum capacity (resets). If no bucket
|
||||
// exists for the given limit Name and client id, a new one will be created WITH
|
||||
// the request's cost deducted from its initial capacity.
|
||||
func (l *Limiter) Spend(name Name, id string, cost int64) (*Decision, error) {
|
||||
if cost <= 0 {
|
||||
return nil, ErrInvalidCost
|
||||
}
|
||||
|
||||
limit, err := l.getLimit(name, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if cost > limit.Burst {
|
||||
return nil, ErrInvalidCostOverLimit
|
||||
}
|
||||
|
||||
tat, err := l.source.Get(bucketKey(name, id))
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrBucketNotFound) {
|
||||
// First request from this client.
|
||||
return l.initialize(limit, name, id, cost)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
d := maybeSpend(l.clk, limit, tat, cost)
|
||||
|
||||
if !d.Allowed {
|
||||
return d, nil
|
||||
}
|
||||
return d, l.source.Set(bucketKey(name, id), d.newTAT)
|
||||
}
|
||||
|
||||
// Refund attempts to refund the cost to the bucket identified by limit name and
|
||||
// client id. The returned *Decision indicates whether the refund was successful
|
||||
// or not. If the refund was successful, the cost of the request was added back
|
||||
// to the bucket's capacity. If the refund is not possible (i.e., the bucket is
|
||||
// already full or the refund amount is invalid), no cost is refunded.
|
||||
//
|
||||
// Note: The amount refunded cannot cause the bucket to exceed its maximum
|
||||
// capacity. However, partial refunds are allowed and are considered successful.
|
||||
// For instance, if a bucket has a maximum capacity of 10 and currently has 5
|
||||
// requests remaining, a refund request of 7 will result in the bucket reaching
|
||||
// its maximum capacity of 10, not 12.
|
||||
func (l *Limiter) Refund(name Name, id string, cost int64) (*Decision, error) {
|
||||
if cost <= 0 {
|
||||
return nil, ErrInvalidCost
|
||||
}
|
||||
|
||||
limit, err := l.getLimit(name, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tat, err := l.source.Get(bucketKey(name, id))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
d := maybeRefund(l.clk, limit, tat, cost)
|
||||
if !d.Allowed {
|
||||
return d, ErrBucketAlreadyFull
|
||||
}
|
||||
return d, l.source.Set(bucketKey(name, id), d.newTAT)
|
||||
|
||||
}
|
||||
|
||||
// Reset resets the specified bucket.
|
||||
func (l *Limiter) Reset(name Name, id string) error {
|
||||
return l.source.Delete(bucketKey(name, id))
|
||||
}
|
||||
|
||||
// initialize creates a new bucket, specified by limit name and id, with the
|
||||
// cost of the request factored into the initial state.
|
||||
func (l *Limiter) initialize(rl limit, name Name, id string, cost int64) (*Decision, error) {
|
||||
d := maybeSpend(l.clk, rl, l.clk.Now(), cost)
|
||||
err := l.source.Set(bucketKey(name, id), d.newTAT)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return d, nil
|
||||
|
||||
}
|
||||
|
||||
// GetLimit returns the limit for the specified by name and id, name is
|
||||
// required, id is optional. If id is left unspecified, the default limit for
|
||||
// the limit specified by name is returned.
|
||||
func (l *Limiter) getLimit(name Name, id string) (limit, error) {
|
||||
if id != "" {
|
||||
// Check for override.
|
||||
ol, ok := l.overrides[bucketKey(name, id)]
|
||||
if ok {
|
||||
return ol, nil
|
||||
}
|
||||
}
|
||||
dl, ok := l.defaults[nameToEnumString(name)]
|
||||
if ok {
|
||||
return dl, nil
|
||||
}
|
||||
return limit{}, fmt.Errorf("limit %q does not exist", name)
|
||||
}
|
|
@ -0,0 +1,315 @@
|
|||
package ratelimits
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/jmhodges/clock"
|
||||
"github.com/letsencrypt/boulder/test"
|
||||
)
|
||||
|
||||
const (
|
||||
tenZeroZeroOne = "10.0.0.1"
|
||||
tenZeroZeroTwo = "10.0.0.2"
|
||||
)
|
||||
|
||||
// newTestLimiter makes a new limiter with the following configuration:
|
||||
// - 'NewRegistrationsPerIPAddress' burst: 20 count: 20 period: 1s
|
||||
func newTestLimiter(t *testing.T) (*Limiter, clock.FakeClock) {
|
||||
clk := clock.NewFake()
|
||||
l, err := NewLimiter(clk, newInmem(), "testdata/working_default.yml", "")
|
||||
test.AssertNotError(t, err, "should not error")
|
||||
return l, clk
|
||||
}
|
||||
|
||||
// newTestLimiterWithOverrides makes a new limiter with the following
|
||||
// configuration:
|
||||
// - 'NewRegistrationsPerIPAddress' burst: 20 count: 20 period: 1s
|
||||
// - 'NewRegistrationsPerIPAddress:10.0.0.2' burst: 40 count: 40 period: 1s
|
||||
func newTestLimiterWithOverrides(t *testing.T) (*Limiter, clock.FakeClock) {
|
||||
clk := clock.NewFake()
|
||||
l, err := NewLimiter(clk, newInmem(), "testdata/working_default.yml", "testdata/working_override.yml")
|
||||
test.AssertNotError(t, err, "should not error")
|
||||
return l, clk
|
||||
}
|
||||
|
||||
func Test_Limiter_initialization_via_Check_and_Spend(t *testing.T) {
|
||||
l, _ := newTestLimiter(t)
|
||||
|
||||
// Check on an empty bucket should initialize it and return the theoretical
|
||||
// next state of that bucket if the cost were spent.
|
||||
d, err := l.Check(NewRegistrationsPerIPAddress, tenZeroZeroOne, 1)
|
||||
test.AssertNotError(t, err, "should not error")
|
||||
test.Assert(t, d.Allowed, "should be allowed")
|
||||
test.AssertEquals(t, d.Remaining, int64(19))
|
||||
// Verify our ResetIn timing is correct. 1 second == 1000 milliseconds and
|
||||
// 1000/20 = 50 milliseconds per request.
|
||||
test.AssertEquals(t, d.ResetIn, time.Millisecond*50)
|
||||
test.AssertEquals(t, d.RetryIn, time.Duration(0))
|
||||
|
||||
// However, that cost should not be spent yet, a 0 cost check should tell us
|
||||
// that we actually have 20 remaining.
|
||||
d, err = l.Check(NewRegistrationsPerIPAddress, tenZeroZeroOne, 0)
|
||||
test.AssertNotError(t, err, "should not error")
|
||||
test.Assert(t, d.Allowed, "should be allowed")
|
||||
test.AssertEquals(t, d.Remaining, int64(20))
|
||||
test.AssertEquals(t, d.ResetIn, time.Duration(0))
|
||||
test.AssertEquals(t, d.RetryIn, time.Duration(0))
|
||||
|
||||
// Reset our bucket.
|
||||
err = l.Reset(NewRegistrationsPerIPAddress, tenZeroZeroOne)
|
||||
test.AssertNotError(t, err, "should not error")
|
||||
|
||||
// Similar to above, but we'll use Spend() instead of Check() to initialize
|
||||
// the bucket. Spend should return the same result as Check.
|
||||
d, err = l.Spend(NewRegistrationsPerIPAddress, tenZeroZeroOne, 1)
|
||||
test.AssertNotError(t, err, "should not error")
|
||||
test.Assert(t, d.Allowed, "should be allowed")
|
||||
test.AssertEquals(t, d.Remaining, int64(19))
|
||||
// Verify our ResetIn timing is correct. 1 second == 1000 milliseconds and
|
||||
// 1000/20 = 50 milliseconds per request.
|
||||
test.AssertEquals(t, d.ResetIn, time.Millisecond*50)
|
||||
test.AssertEquals(t, d.RetryIn, time.Duration(0))
|
||||
|
||||
// However, that cost should not be spent yet, a 0 cost check should tell us
|
||||
// that we actually have 19 remaining.
|
||||
d, err = l.Check(NewRegistrationsPerIPAddress, tenZeroZeroOne, 0)
|
||||
test.AssertNotError(t, err, "should not error")
|
||||
test.Assert(t, d.Allowed, "should be allowed")
|
||||
test.AssertEquals(t, d.Remaining, int64(19))
|
||||
// Verify our ResetIn is correct. 1 second == 1000 milliseconds and
|
||||
// 1000/20 = 50 milliseconds per request.
|
||||
test.AssertEquals(t, d.ResetIn, time.Millisecond*50)
|
||||
test.AssertEquals(t, d.RetryIn, time.Duration(0))
|
||||
}
|
||||
|
||||
func Test_Limiter_Refund_and_Spend_cost_err(t *testing.T) {
|
||||
l, _ := newTestLimiter(t)
|
||||
|
||||
// Spend a cost of 0, which should fail.
|
||||
_, err := l.Spend(NewRegistrationsPerIPAddress, tenZeroZeroOne, 0)
|
||||
test.AssertErrorIs(t, err, ErrInvalidCost)
|
||||
|
||||
// Spend a negative cost, which should fail.
|
||||
_, err = l.Spend(NewRegistrationsPerIPAddress, tenZeroZeroOne, -1)
|
||||
test.AssertErrorIs(t, err, ErrInvalidCost)
|
||||
|
||||
// Refund a cost of 0, which should fail.
|
||||
_, err = l.Refund(NewRegistrationsPerIPAddress, tenZeroZeroOne, 0)
|
||||
test.AssertErrorIs(t, err, ErrInvalidCost)
|
||||
|
||||
// Refund a negative cost, which should fail.
|
||||
_, err = l.Refund(NewRegistrationsPerIPAddress, tenZeroZeroOne, -1)
|
||||
test.AssertErrorIs(t, err, ErrInvalidCost)
|
||||
}
|
||||
|
||||
func Test_Limiter_with_bad_limits_path(t *testing.T) {
|
||||
_, err := NewLimiter(clock.NewFake(), newInmem(), "testdata/does-not-exist.yml", "")
|
||||
test.AssertError(t, err, "should error")
|
||||
|
||||
_, err = NewLimiter(clock.NewFake(), newInmem(), "testdata/defaults.yml", "testdata/does-not-exist.yml")
|
||||
test.AssertError(t, err, "should error")
|
||||
}
|
||||
|
||||
func Test_Limiter_Check_bad_cost(t *testing.T) {
|
||||
l, _ := newTestLimiter(t)
|
||||
_, err := l.Check(NewRegistrationsPerIPAddress, tenZeroZeroOne, -1)
|
||||
test.AssertErrorIs(t, err, ErrInvalidCostForCheck)
|
||||
}
|
||||
|
||||
func Test_Limiter_Check_limit_no_exist(t *testing.T) {
|
||||
l, _ := newTestLimiter(t)
|
||||
_, err := l.Check(Name(9999), tenZeroZeroOne, 1)
|
||||
test.AssertError(t, err, "should error")
|
||||
}
|
||||
|
||||
func Test_Limiter_getLimit_no_exist(t *testing.T) {
|
||||
l, _ := newTestLimiter(t)
|
||||
_, err := l.getLimit(Name(9999), "")
|
||||
test.AssertError(t, err, "should error")
|
||||
}
|
||||
|
||||
func Test_Limiter_with_defaults(t *testing.T) {
|
||||
l, clk := newTestLimiter(t)
|
||||
|
||||
// Attempt to spend 21 requests (a cost > the limit burst capacity), this
|
||||
// should fail with a specific error.
|
||||
_, err := l.Spend(NewRegistrationsPerIPAddress, tenZeroZeroOne, 21)
|
||||
test.AssertErrorIs(t, err, ErrInvalidCostOverLimit)
|
||||
|
||||
// Attempt to spend all 20 requests, this should succeed.
|
||||
d, err := l.Spend(NewRegistrationsPerIPAddress, tenZeroZeroOne, 20)
|
||||
test.AssertNotError(t, err, "should not error")
|
||||
test.Assert(t, d.Allowed, "should be allowed")
|
||||
test.AssertEquals(t, d.Remaining, int64(0))
|
||||
test.AssertEquals(t, d.ResetIn, time.Second)
|
||||
|
||||
// Attempting to spend 1 more, this should fail.
|
||||
d, err = l.Spend(NewRegistrationsPerIPAddress, tenZeroZeroOne, 1)
|
||||
test.AssertNotError(t, err, "should not error")
|
||||
test.Assert(t, !d.Allowed, "should not be allowed")
|
||||
test.AssertEquals(t, d.Remaining, int64(0))
|
||||
test.AssertEquals(t, d.ResetIn, time.Second)
|
||||
|
||||
// Verify our ResetIn is correct. 1 second == 1000 milliseconds and
|
||||
// 1000/20 = 50 milliseconds per request.
|
||||
test.AssertEquals(t, d.RetryIn, time.Millisecond*50)
|
||||
|
||||
// Wait 50 milliseconds and try again.
|
||||
clk.Add(d.RetryIn)
|
||||
|
||||
// We should be allowed to spend 1 more request.
|
||||
d, err = l.Spend(NewRegistrationsPerIPAddress, tenZeroZeroOne, 1)
|
||||
test.AssertNotError(t, err, "should not error")
|
||||
test.Assert(t, d.Allowed, "should be allowed")
|
||||
test.AssertEquals(t, d.Remaining, int64(0))
|
||||
test.AssertEquals(t, d.ResetIn, time.Second)
|
||||
|
||||
// Wait 1 second for a full bucket reset.
|
||||
clk.Add(d.ResetIn)
|
||||
|
||||
// Quickly spend 20 requests in a row.
|
||||
for i := 0; i < 20; i++ {
|
||||
d, err = l.Spend(NewRegistrationsPerIPAddress, tenZeroZeroOne, 1)
|
||||
test.AssertNotError(t, err, "should not error")
|
||||
test.Assert(t, d.Allowed, "should be allowed")
|
||||
test.AssertEquals(t, d.Remaining, int64(19-i))
|
||||
}
|
||||
|
||||
// Attempting to spend 1 more, this should fail.
|
||||
d, err = l.Spend(NewRegistrationsPerIPAddress, tenZeroZeroOne, 1)
|
||||
test.AssertNotError(t, err, "should not error")
|
||||
test.Assert(t, !d.Allowed, "should not be allowed")
|
||||
test.AssertEquals(t, d.Remaining, int64(0))
|
||||
test.AssertEquals(t, d.ResetIn, time.Second)
|
||||
}
|
||||
|
||||
func Test_Limiter_with_limit_overrides(t *testing.T) {
|
||||
l, clk := newTestLimiterWithOverrides(t)
|
||||
|
||||
// Attempt to check a spend of 41 requests (a cost > the limit burst
|
||||
// capacity), this should fail with a specific error.
|
||||
_, err := l.Check(NewRegistrationsPerIPAddress, tenZeroZeroTwo, 41)
|
||||
test.AssertErrorIs(t, err, ErrInvalidCostOverLimit)
|
||||
|
||||
// Attempt to spend 41 requests (a cost > the limit burst capacity), this
|
||||
// should fail with a specific error.
|
||||
_, err = l.Spend(NewRegistrationsPerIPAddress, tenZeroZeroTwo, 41)
|
||||
test.AssertErrorIs(t, err, ErrInvalidCostOverLimit)
|
||||
|
||||
// Attempt to spend all 40 requests, this should succeed.
|
||||
d, err := l.Spend(NewRegistrationsPerIPAddress, tenZeroZeroTwo, 40)
|
||||
test.AssertNotError(t, err, "should not error")
|
||||
test.Assert(t, d.Allowed, "should be allowed")
|
||||
|
||||
// Attempting to spend 1 more, this should fail.
|
||||
d, err = l.Spend(NewRegistrationsPerIPAddress, tenZeroZeroTwo, 1)
|
||||
test.AssertNotError(t, err, "should not error")
|
||||
test.Assert(t, !d.Allowed, "should not be allowed")
|
||||
test.AssertEquals(t, d.Remaining, int64(0))
|
||||
test.AssertEquals(t, d.ResetIn, time.Second)
|
||||
|
||||
// Verify our ResetIn is correct. 1 second == 1000 milliseconds and
|
||||
// 1000/40 = 25 milliseconds per request.
|
||||
test.AssertEquals(t, d.RetryIn, time.Millisecond*25)
|
||||
|
||||
// Wait 50 milliseconds and try again.
|
||||
clk.Add(d.RetryIn)
|
||||
|
||||
// We should be allowed to spend 1 more request.
|
||||
d, err = l.Spend(NewRegistrationsPerIPAddress, tenZeroZeroTwo, 1)
|
||||
test.AssertNotError(t, err, "should not error")
|
||||
test.Assert(t, d.Allowed, "should be allowed")
|
||||
test.AssertEquals(t, d.Remaining, int64(0))
|
||||
test.AssertEquals(t, d.ResetIn, time.Second)
|
||||
|
||||
// Wait 1 second for a full bucket reset.
|
||||
clk.Add(d.ResetIn)
|
||||
|
||||
// Quickly spend 40 requests in a row.
|
||||
for i := 0; i < 40; i++ {
|
||||
d, err = l.Spend(NewRegistrationsPerIPAddress, tenZeroZeroTwo, 1)
|
||||
test.AssertNotError(t, err, "should not error")
|
||||
test.Assert(t, d.Allowed, "should be allowed")
|
||||
test.AssertEquals(t, d.Remaining, int64(39-i))
|
||||
}
|
||||
|
||||
// Attempting to spend 1 more, this should fail.
|
||||
d, err = l.Spend(NewRegistrationsPerIPAddress, tenZeroZeroTwo, 1)
|
||||
test.AssertNotError(t, err, "should not error")
|
||||
test.Assert(t, !d.Allowed, "should not be allowed")
|
||||
test.AssertEquals(t, d.Remaining, int64(0))
|
||||
test.AssertEquals(t, d.ResetIn, time.Second)
|
||||
}
|
||||
|
||||
func Test_Limiter_with_new_clients(t *testing.T) {
|
||||
l, _ := newTestLimiter(t)
|
||||
|
||||
// Attempt to spend all 20 requests, this should succeed.
|
||||
d, err := l.Spend(NewRegistrationsPerIPAddress, tenZeroZeroOne, 20)
|
||||
test.AssertNotError(t, err, "should not error")
|
||||
test.Assert(t, d.Allowed, "should be allowed")
|
||||
test.AssertEquals(t, d.Remaining, int64(0))
|
||||
test.AssertEquals(t, d.ResetIn, time.Second)
|
||||
|
||||
// Another new client, spend 1 and check our remaining.
|
||||
d, err = l.Spend(NewRegistrationsPerIPAddress, "10.0.0.100", 1)
|
||||
test.AssertNotError(t, err, "should not error")
|
||||
test.Assert(t, d.Allowed, "should be allowed")
|
||||
test.AssertEquals(t, d.Remaining, int64(19))
|
||||
test.AssertEquals(t, d.RetryIn, time.Duration(0))
|
||||
|
||||
// 1 second == 1000 milliseconds and 1000/20 = 50 milliseconds per request.
|
||||
test.AssertEquals(t, d.ResetIn, time.Millisecond*50)
|
||||
}
|
||||
|
||||
func Test_Limiter_Refund_and_Reset(t *testing.T) {
|
||||
l, clk := newTestLimiter(t)
|
||||
|
||||
// Attempt to spend all 20 requests, this should succeed.
|
||||
d, err := l.Spend(NewRegistrationsPerIPAddress, tenZeroZeroOne, 20)
|
||||
test.AssertNotError(t, err, "should not error")
|
||||
test.Assert(t, d.Allowed, "should be allowed")
|
||||
test.AssertEquals(t, d.Remaining, int64(0))
|
||||
test.AssertEquals(t, d.ResetIn, time.Second)
|
||||
|
||||
// Refund 10 requests.
|
||||
d, err = l.Refund(NewRegistrationsPerIPAddress, tenZeroZeroOne, 10)
|
||||
test.AssertNotError(t, err, "should not error")
|
||||
test.AssertEquals(t, d.Remaining, int64(10))
|
||||
|
||||
// Spend 10 requests, this should succeed.
|
||||
d, err = l.Spend(NewRegistrationsPerIPAddress, tenZeroZeroOne, 10)
|
||||
test.AssertNotError(t, err, "should not error")
|
||||
test.Assert(t, d.Allowed, "should be allowed")
|
||||
test.AssertEquals(t, d.Remaining, int64(0))
|
||||
test.AssertEquals(t, d.ResetIn, time.Second)
|
||||
|
||||
err = l.Reset(NewRegistrationsPerIPAddress, tenZeroZeroOne)
|
||||
test.AssertNotError(t, err, "should not error")
|
||||
|
||||
// Attempt to spend 20 more requests, this should succeed.
|
||||
d, err = l.Spend(NewRegistrationsPerIPAddress, tenZeroZeroOne, 20)
|
||||
test.AssertNotError(t, err, "should not error")
|
||||
test.Assert(t, d.Allowed, "should be allowed")
|
||||
test.AssertEquals(t, d.Remaining, int64(0))
|
||||
test.AssertEquals(t, d.ResetIn, time.Second)
|
||||
|
||||
// Reset to full.
|
||||
clk.Add(d.ResetIn)
|
||||
|
||||
// Refund 1 requests above our limit, this should fail.
|
||||
d, err = l.Refund(NewRegistrationsPerIPAddress, tenZeroZeroOne, 1)
|
||||
test.AssertErrorIs(t, err, ErrBucketAlreadyFull)
|
||||
test.AssertEquals(t, d.Remaining, int64(20))
|
||||
}
|
||||
|
||||
func Test_Limiter_Check_Spend_parity(t *testing.T) {
|
||||
il, _ := newTestLimiter(t)
|
||||
jl, _ := newTestLimiter(t)
|
||||
i, err := il.Check(NewRegistrationsPerIPAddress, tenZeroZeroOne, 1)
|
||||
test.AssertNotError(t, err, "should not error")
|
||||
j, err := jl.Spend(NewRegistrationsPerIPAddress, tenZeroZeroOne, 1)
|
||||
test.AssertNotError(t, err, "should not error")
|
||||
test.AssertDeepEquals(t, i.Remaining, j.Remaining)
|
||||
}
|
|
@ -0,0 +1,202 @@
|
|||
package ratelimits
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/letsencrypt/boulder/policy"
|
||||
)
|
||||
|
||||
// Name is an enumeration of all rate limit names. It is used to intern rate
|
||||
// limit names as strings and to provide a type-safe way to refer to rate
|
||||
// limits.
|
||||
//
|
||||
// IMPORTANT: If you add a new limit Name, you MUST add it to the 'nameToString'
|
||||
// mapping and idValidForName function below.
|
||||
type Name int
|
||||
|
||||
const (
|
||||
// Unknown is the zero value of Name and is used to indicate an unknown
|
||||
// limit name.
|
||||
Unknown Name = iota
|
||||
|
||||
// NewRegistrationsPerIPAddress uses bucket key 'enum:ipAddress'.
|
||||
NewRegistrationsPerIPAddress
|
||||
|
||||
// NewRegistrationsPerIPv6Range uses bucket key 'enum:ipv6rangeCIDR'. The
|
||||
// address range must be a /48.
|
||||
NewRegistrationsPerIPv6Range
|
||||
|
||||
// NewOrdersPerAccount uses bucket key 'enum:regId'.
|
||||
NewOrdersPerAccount
|
||||
|
||||
// FailedAuthorizationsPerAccount uses bucket key 'enum:regId', where regId
|
||||
// is the registration id of the account.
|
||||
FailedAuthorizationsPerAccount
|
||||
|
||||
// CertificatesPerDomainPerAccount uses bucket key 'enum:regId:domain',
|
||||
// where name is the a name in a certificate issued to the account matching
|
||||
// regId.
|
||||
CertificatesPerDomainPerAccount
|
||||
|
||||
// CertificatesPerFQDNSetPerAccount uses bucket key 'enum:regId:fqdnSet',
|
||||
// where nameSet is a set of names in a certificate issued to the account
|
||||
// matching regId.
|
||||
CertificatesPerFQDNSetPerAccount
|
||||
)
|
||||
|
||||
// nameToString is a map of Name values to string names.
|
||||
var nameToString = map[Name]string{
|
||||
Unknown: "Unknown",
|
||||
NewRegistrationsPerIPAddress: "NewRegistrationsPerIPAddress",
|
||||
NewRegistrationsPerIPv6Range: "NewRegistrationsPerIPv6Range",
|
||||
NewOrdersPerAccount: "NewOrdersPerAccount",
|
||||
FailedAuthorizationsPerAccount: "FailedAuthorizationsPerAccount",
|
||||
CertificatesPerDomainPerAccount: "CertificatesPerDomainPerAccount",
|
||||
CertificatesPerFQDNSetPerAccount: "CertificatesPerFQDNSetPerAccount",
|
||||
}
|
||||
|
||||
// validIPAddress validates that the provided string is a valid IP address.
|
||||
func validIPAddress(id string) error {
|
||||
ip := net.ParseIP(id)
|
||||
if ip == nil {
|
||||
return fmt.Errorf("invalid IP address, %q must be an IP address", id)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// validIPv6RangeCIDR validates that the provided string is formatted is an IPv6
|
||||
// CIDR range with a /48 mask.
|
||||
func validIPv6RangeCIDR(id string) error {
|
||||
_, ipNet, err := net.ParseCIDR(id)
|
||||
if err != nil {
|
||||
return fmt.Errorf(
|
||||
"invalid CIDR, %q must be an IPv6 CIDR range", id)
|
||||
}
|
||||
ones, _ := ipNet.Mask.Size()
|
||||
if ones != 48 {
|
||||
// This also catches the case where the range is an IPv4 CIDR, since an
|
||||
// IPv4 CIDR can't have a /48 subnet mask - the maximum is /32.
|
||||
return fmt.Errorf(
|
||||
"invalid CIDR, %q must be /48", id)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateRegId validates that the provided string is a valid ACME regId.
|
||||
func validateRegId(id string) error {
|
||||
_, err := strconv.ParseUint(id, 10, 64)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid regId, %q must be an ACME registration Id", id)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateRegIdDomain validates that the provided string is formatted
|
||||
// 'regId:domain', where regId is an ACME registration Id and domain is a single
|
||||
// domain name.
|
||||
func validateRegIdDomain(id string) error {
|
||||
parts := strings.SplitN(id, ":", 2)
|
||||
if len(parts) != 2 {
|
||||
return fmt.Errorf(
|
||||
"invalid regId:domain, %q must be formatted 'regId:domain'", id)
|
||||
}
|
||||
if validateRegId(parts[0]) != nil {
|
||||
return fmt.Errorf(
|
||||
"invalid regId, %q must be formatted 'regId:domain'", id)
|
||||
}
|
||||
if policy.ValidDomain(parts[1]) != nil {
|
||||
return fmt.Errorf(
|
||||
"invalid domain, %q must be formatted 'regId:domain'", id)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateRegIdFQDNSet validates that the provided string is formatted
|
||||
// 'regId:fqdnSet', where regId is an ACME registration Id and fqdnSet is a
|
||||
// comma-separated list of domain names.
|
||||
func validateRegIdFQDNSet(id string) error {
|
||||
parts := strings.SplitN(id, ":", 2)
|
||||
if len(parts) != 2 {
|
||||
return fmt.Errorf(
|
||||
"invalid regId:fqdnSet, %q must be formatted 'regId:fqdnSet'", id)
|
||||
}
|
||||
if validateRegId(parts[0]) != nil {
|
||||
return fmt.Errorf(
|
||||
"invalid regId, %q must be formatted 'regId:fqdnSet'", id)
|
||||
}
|
||||
domains := strings.Split(parts[1], ",")
|
||||
if len(domains) == 0 {
|
||||
return fmt.Errorf(
|
||||
"invalid fqdnSet, %q must be formatted 'regId:fqdnSet'", id)
|
||||
}
|
||||
for _, domain := range domains {
|
||||
if policy.ValidDomain(domain) != nil {
|
||||
return fmt.Errorf(
|
||||
"invalid domain, %q must be formatted 'regId:fqdnSet'", id)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateIdForName(name Name, id string) error {
|
||||
switch name {
|
||||
case NewRegistrationsPerIPAddress:
|
||||
// 'enum:ipaddress'
|
||||
return validIPAddress(id)
|
||||
|
||||
case NewRegistrationsPerIPv6Range:
|
||||
// 'enum:ipv6rangeCIDR'
|
||||
return validIPv6RangeCIDR(id)
|
||||
|
||||
case NewOrdersPerAccount, FailedAuthorizationsPerAccount:
|
||||
// 'enum:regId'
|
||||
return validateRegId(id)
|
||||
|
||||
case CertificatesPerDomainPerAccount:
|
||||
// 'enum:regId:domain'
|
||||
return validateRegIdDomain(id)
|
||||
|
||||
case CertificatesPerFQDNSetPerAccount:
|
||||
// 'enum:regId:fqdnSet'
|
||||
return validateRegIdFQDNSet(id)
|
||||
|
||||
case Unknown:
|
||||
fallthrough
|
||||
|
||||
default:
|
||||
// This should never happen.
|
||||
return fmt.Errorf("unknown limit enum %q", name)
|
||||
}
|
||||
}
|
||||
|
||||
// stringToName is a map of string names to Name values.
|
||||
var stringToName = func() map[string]Name {
|
||||
m := make(map[string]Name, len(nameToString))
|
||||
for k, v := range nameToString {
|
||||
m[v] = k
|
||||
}
|
||||
return m
|
||||
}()
|
||||
|
||||
// limitNames is a slice of all rate limit names.
|
||||
var limitNames = func() []string {
|
||||
names := make([]string, len(nameToString))
|
||||
for _, v := range nameToString {
|
||||
names = append(names, v)
|
||||
}
|
||||
return names
|
||||
}()
|
||||
|
||||
// nameToEnumString converts the integer value of the Name enumeration to its
|
||||
// string representation.
|
||||
func nameToEnumString(s Name) string {
|
||||
return strconv.Itoa(int(s))
|
||||
}
|
||||
|
||||
// bucketKey returns the key used to store a rate limit bucket.
|
||||
func bucketKey(name Name, id string) string {
|
||||
return nameToEnumString(name) + ":" + id
|
||||
}
|
|
@ -0,0 +1,57 @@
|
|||
package ratelimits
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ErrBucketNotFound indicates that the bucket was not found.
|
||||
var ErrBucketNotFound = fmt.Errorf("bucket not found")
|
||||
|
||||
// source is an interface for creating and modifying TATs.
|
||||
type source interface {
|
||||
// Set stores the TAT at the specified bucketKey ('name:id').
|
||||
Set(bucketKey string, tat time.Time) error
|
||||
|
||||
// Get retrieves the TAT at the specified bucketKey ('name:id').
|
||||
Get(bucketKey string) (time.Time, error)
|
||||
|
||||
// Delete deletes the TAT at the specified bucketKey ('name:id').
|
||||
Delete(bucketKey string) error
|
||||
}
|
||||
|
||||
// inmem is an in-memory implementation of the source interface used for
|
||||
// testing.
|
||||
type inmem struct {
|
||||
sync.RWMutex
|
||||
m map[string]time.Time
|
||||
}
|
||||
|
||||
func newInmem() *inmem {
|
||||
return &inmem{m: make(map[string]time.Time)}
|
||||
}
|
||||
|
||||
func (in *inmem) Set(bucketKey string, tat time.Time) error {
|
||||
in.Lock()
|
||||
defer in.Unlock()
|
||||
in.m[bucketKey] = tat
|
||||
return nil
|
||||
}
|
||||
|
||||
func (in *inmem) Get(bucketKey string) (time.Time, error) {
|
||||
in.RLock()
|
||||
defer in.RUnlock()
|
||||
tat, ok := in.m[bucketKey]
|
||||
if !ok {
|
||||
return time.Time{}, ErrBucketNotFound
|
||||
}
|
||||
return tat, nil
|
||||
}
|
||||
|
||||
func (in *inmem) Delete(bucketKey string) error {
|
||||
in.Lock()
|
||||
defer in.Unlock()
|
||||
delete(in.m, bucketKey)
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,4 @@
|
|||
NewRegistrationsPerIPAddress:
|
||||
burst: 0
|
||||
count: 20
|
||||
period: 1s
|
|
@ -0,0 +1,4 @@
|
|||
"":
|
||||
burst: 20
|
||||
count: 20
|
||||
period: 1s
|
|
@ -0,0 +1,4 @@
|
|||
UsageRequestsPerIPv10Address:
|
||||
burst: 20
|
||||
count: 20
|
||||
period: 1s
|
|
@ -0,0 +1,8 @@
|
|||
NewRegistrationsPerIPAddress:
|
||||
burst: 20
|
||||
count: 20
|
||||
period: 1s
|
||||
UsageRequestsPerIPv10Address:
|
||||
burst: 20
|
||||
count: 20
|
||||
period: 1s
|
|
@ -0,0 +1,4 @@
|
|||
NewRegistrationsPerIPAddress:10.0.0.2:
|
||||
burst: 0
|
||||
count: 40
|
||||
period: 1s
|
|
@ -0,0 +1,4 @@
|
|||
"UsageRequestsPerIPv10Address:":
|
||||
burst: 40
|
||||
count: 40
|
||||
period: 1s
|
|
@ -0,0 +1,4 @@
|
|||
":10.0.0.2":
|
||||
burst: 40
|
||||
count: 40
|
||||
period: 1s
|
|
@ -0,0 +1,4 @@
|
|||
UsageRequestsPerIPv10Address:10.0.0.2:
|
||||
burst: 40
|
||||
count: 40
|
||||
period: 1s
|
|
@ -0,0 +1,8 @@
|
|||
NewRegistrationsPerIPAddress:10.0.0.2:
|
||||
burst: 40
|
||||
count: 40
|
||||
period: 1s
|
||||
UsageRequestsPerIPv10Address:10.0.0.5:
|
||||
burst: 40
|
||||
count: 40
|
||||
period: 1s
|
|
@ -0,0 +1,12 @@
|
|||
NewRegistrationsPerIPAddress:10.0.0.2:
|
||||
burst: 40
|
||||
count: 40
|
||||
period: 1s
|
||||
NewRegistrationsPerIPAddress:10.0.0.5:
|
||||
burst: 40
|
||||
count: 40
|
||||
period: 1s
|
||||
NewRegistrationsPerIPAddress:lol:
|
||||
burst: 40
|
||||
count: 40
|
||||
period: 1s
|
|
@ -0,0 +1,4 @@
|
|||
NewRegistrationsPerIPAddress:
|
||||
burst: 20
|
||||
count: 20
|
||||
period: 1s
|
|
@ -0,0 +1,8 @@
|
|||
NewRegistrationsPerIPAddress:
|
||||
burst: 20
|
||||
count: 20
|
||||
period: 1s
|
||||
NewRegistrationsPerIPv6Range:
|
||||
burst: 30
|
||||
count: 30
|
||||
period: 2s
|
|
@ -0,0 +1,4 @@
|
|||
NewRegistrationsPerIPAddress:10.0.0.2:
|
||||
burst: 40
|
||||
count: 40
|
||||
period: 1s
|
|
@ -0,0 +1,4 @@
|
|||
CertificatesPerDomainPerAccount:12345678:example.com:
|
||||
burst: 40
|
||||
count: 40
|
||||
period: 1s
|
|
@ -0,0 +1,8 @@
|
|||
NewRegistrationsPerIPAddress:10.0.0.2:
|
||||
burst: 40
|
||||
count: 40
|
||||
period: 1s
|
||||
NewRegistrationsPerIPv6Range:2001:0db8:0000::/48:
|
||||
burst: 50
|
||||
count: 50
|
||||
period: 2s
|
|
@ -0,0 +1,12 @@
|
|||
CertificatesPerFQDNSetPerAccount:12345678:example.com:
|
||||
burst: 40
|
||||
count: 40
|
||||
period: 1s
|
||||
CertificatesPerFQDNSetPerAccount:12345678:example.com,example.net:
|
||||
burst: 50
|
||||
count: 50
|
||||
period: 2s
|
||||
CertificatesPerFQDNSetPerAccount:12345678:example.com,example.net,example.org:
|
||||
burst: 60
|
||||
count: 60
|
||||
period: 3s
|
13
sa/model.go
13
sa/model.go
|
@ -12,7 +12,6 @@ import (
|
|||
"net"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"golang.org/x/exp/slices"
|
||||
|
@ -873,14 +872,6 @@ type crlEntryModel struct {
|
|||
RevokedDate time.Time `db:"revokedDate"`
|
||||
}
|
||||
|
||||
// HashNames returns a hash of the names requested. This is intended for use
|
||||
// when interacting with the orderFqdnSets table.
|
||||
func HashNames(names []string) []byte {
|
||||
names = core.UniqueLowerNames(names)
|
||||
hash := sha256.Sum256([]byte(strings.Join(names, ",")))
|
||||
return hash[:]
|
||||
}
|
||||
|
||||
// orderFQDNSet contains the SHA256 hash of the lowercased, comma joined names
|
||||
// from a new-order request, along with the corresponding orderID, the
|
||||
// registration ID, and the order expiry. This is used to find
|
||||
|
@ -895,7 +886,7 @@ type orderFQDNSet struct {
|
|||
|
||||
func addFQDNSet(ctx context.Context, db db.Inserter, names []string, serial string, issued time.Time, expires time.Time) error {
|
||||
return db.Insert(ctx, &core.FQDNSet{
|
||||
SetHash: HashNames(names),
|
||||
SetHash: core.HashNames(names),
|
||||
Serial: serial,
|
||||
Issued: issued,
|
||||
Expires: expires,
|
||||
|
@ -914,7 +905,7 @@ func addOrderFQDNSet(
|
|||
regID int64,
|
||||
expires time.Time) error {
|
||||
return db.Insert(ctx, &orderFQDNSet{
|
||||
SetHash: HashNames(names),
|
||||
SetHash: core.HashNames(names),
|
||||
OrderID: orderID,
|
||||
RegistrationID: regID,
|
||||
Expires: expires,
|
||||
|
|
|
@ -2893,33 +2893,6 @@ func TestBlockedKeyRevokedBy(t *testing.T) {
|
|||
test.AssertNotError(t, err, "AddBlockedKey failed")
|
||||
}
|
||||
|
||||
func TestHashNames(t *testing.T) {
|
||||
// Test that it is deterministic
|
||||
h1 := HashNames([]string{"a"})
|
||||
h2 := HashNames([]string{"a"})
|
||||
test.AssertByteEquals(t, h1, h2)
|
||||
|
||||
// Test that it differentiates
|
||||
h1 = HashNames([]string{"a"})
|
||||
h2 = HashNames([]string{"b"})
|
||||
test.Assert(t, !bytes.Equal(h1, h2), "Should have been different")
|
||||
|
||||
// Test that it is not subject to ordering
|
||||
h1 = HashNames([]string{"a", "b"})
|
||||
h2 = HashNames([]string{"b", "a"})
|
||||
test.AssertByteEquals(t, h1, h2)
|
||||
|
||||
// Test that it is not subject to case
|
||||
h1 = HashNames([]string{"a", "b"})
|
||||
h2 = HashNames([]string{"A", "B"})
|
||||
test.AssertByteEquals(t, h1, h2)
|
||||
|
||||
// Test that it is not subject to duplication
|
||||
h1 = HashNames([]string{"a", "a"})
|
||||
h2 = HashNames([]string{"a"})
|
||||
test.AssertByteEquals(t, h1, h2)
|
||||
}
|
||||
|
||||
func TestIncidentsForSerial(t *testing.T) {
|
||||
sa, _, cleanUp := initSA(t)
|
||||
defer cleanUp()
|
||||
|
|
|
@ -516,7 +516,7 @@ func (ssa *SQLStorageAuthorityRO) CountFQDNSets(ctx context.Context, req *sapb.C
|
|||
`SELECT COUNT(*) FROM fqdnSets
|
||||
WHERE setHash = ?
|
||||
AND issued > ?`,
|
||||
HashNames(req.Domains),
|
||||
core.HashNames(req.Domains),
|
||||
ssa.clk.Now().Add(-time.Duration(req.Window)),
|
||||
)
|
||||
return &sapb.Count{Count: count}, err
|
||||
|
@ -544,7 +544,7 @@ func (ssa *SQLStorageAuthorityRO) FQDNSetTimestampsForWindow(ctx context.Context
|
|||
WHERE setHash = ?
|
||||
AND issued > ?
|
||||
ORDER BY issued DESC`,
|
||||
HashNames(req.Domains),
|
||||
core.HashNames(req.Domains),
|
||||
ssa.clk.Now().Add(-time.Duration(req.Window)),
|
||||
)
|
||||
if err != nil {
|
||||
|
@ -586,7 +586,7 @@ type oneSelectorFunc func(ctx context.Context, holder interface{}, query string,
|
|||
// checkFQDNSetExists uses the given oneSelectorFunc to check whether an fqdnSet
|
||||
// for the given names exists.
|
||||
func (ssa *SQLStorageAuthorityRO) checkFQDNSetExists(ctx context.Context, selector oneSelectorFunc, names []string) (bool, error) {
|
||||
namehash := HashNames(names)
|
||||
namehash := core.HashNames(names)
|
||||
var exists bool
|
||||
err := selector(
|
||||
ctx,
|
||||
|
@ -761,7 +761,7 @@ func (ssa *SQLStorageAuthorityRO) GetOrderForNames(ctx context.Context, req *sap
|
|||
}
|
||||
|
||||
// Hash the names requested for lookup in the orderFqdnSets table
|
||||
fqdnHash := HashNames(req.Names)
|
||||
fqdnHash := core.HashNames(req.Names)
|
||||
|
||||
// Find a possibly-suitable order. We don't include the account ID or order
|
||||
// status in this query because there's no index that includes those, so
|
||||
|
|
Loading…
Reference in New Issue