Initial implementation of key-value rate limits (#6947)
This design seeks to reduce read-pressure on our DB by moving rate limit tabulation to a key-value datastore. This PR provides the following: - (README.md) a short guide to the schemas, formats, and concepts introduced in this PR - (source.go) an interface for storing, retrieving, and resetting a subscriber bucket - (name.go) an enumeration of all defined rate limits - (limit.go) a schema for defining default limits and per-subscriber overrides - (limiter.go) a high-level API for interacting with key-value rate limits - (gcra.go) an implementation of the Generic Cell Rate Algorithm, a leaky bucket-style scheduling algorithm, used to calculate the present or future capacity of a subscriber bucket using spend and refund operations Note: the included source implementation is test-only and currently accomplished using a simple in-memory map protected by a mutex, implementations using Redis and potentially other data stores will follow. Part of #5545
This commit is contained in:
parent
e955494955
commit
055f620c4b
|
@ -304,7 +304,7 @@ func (m *mailer) updateLastNagTimestampsChunk(ctx context.Context, certs []*x509
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mailer) certIsRenewed(ctx context.Context, names []string, issued time.Time) (bool, error) {
|
func (m *mailer) certIsRenewed(ctx context.Context, names []string, issued time.Time) (bool, error) {
|
||||||
namehash := sa.HashNames(names)
|
namehash := core.HashNames(names)
|
||||||
|
|
||||||
var present bool
|
var present bool
|
||||||
err := m.dbMap.SelectOne(
|
err := m.dbMap.SelectOne(
|
||||||
|
|
|
@ -353,7 +353,7 @@ func TestNoContactCertIsRenewed(t *testing.T) {
|
||||||
setupDBMap, err := sa.DBMapForTest(vars.DBConnSAFullPerms)
|
setupDBMap, err := sa.DBMapForTest(vars.DBConnSAFullPerms)
|
||||||
test.AssertNotError(t, err, "setting up DB")
|
test.AssertNotError(t, err, "setting up DB")
|
||||||
err = setupDBMap.Insert(ctx, &core.FQDNSet{
|
err = setupDBMap.Insert(ctx, &core.FQDNSet{
|
||||||
SetHash: sa.HashNames(names),
|
SetHash: core.HashNames(names),
|
||||||
Serial: core.SerialToString(serial2),
|
Serial: core.SerialToString(serial2),
|
||||||
Issued: testCtx.fc.Now().Add(time.Hour),
|
Issued: testCtx.fc.Now().Add(time.Hour),
|
||||||
Expires: expires.Add(time.Hour),
|
Expires: expires.Add(time.Hour),
|
||||||
|
@ -580,13 +580,13 @@ func addExpiringCerts(t *testing.T, ctx *testCtx) []certDERWithRegID {
|
||||||
test.AssertNotError(t, err, "creating cert D")
|
test.AssertNotError(t, err, "creating cert D")
|
||||||
|
|
||||||
fqdnStatusD := &core.FQDNSet{
|
fqdnStatusD := &core.FQDNSet{
|
||||||
SetHash: sa.HashNames(certDNames),
|
SetHash: core.HashNames(certDNames),
|
||||||
Serial: serial4String,
|
Serial: serial4String,
|
||||||
Issued: ctx.fc.Now().AddDate(0, 0, -87),
|
Issued: ctx.fc.Now().AddDate(0, 0, -87),
|
||||||
Expires: ctx.fc.Now().AddDate(0, 0, 3),
|
Expires: ctx.fc.Now().AddDate(0, 0, 3),
|
||||||
}
|
}
|
||||||
fqdnStatusDRenewed := &core.FQDNSet{
|
fqdnStatusDRenewed := &core.FQDNSet{
|
||||||
SetHash: sa.HashNames(certDNames),
|
SetHash: core.HashNames(certDNames),
|
||||||
Serial: serial5String,
|
Serial: serial5String,
|
||||||
Issued: ctx.fc.Now().AddDate(0, 0, -3),
|
Issued: ctx.fc.Now().AddDate(0, 0, -3),
|
||||||
Expires: ctx.fc.Now().AddDate(0, 0, 87),
|
Expires: ctx.fc.Now().AddDate(0, 0, 87),
|
||||||
|
@ -747,7 +747,7 @@ func TestCertIsRenewed(t *testing.T) {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
fqdnStatus := &core.FQDNSet{
|
fqdnStatus := &core.FQDNSet{
|
||||||
SetHash: sa.HashNames(testData.DNS),
|
SetHash: core.HashNames(testData.DNS),
|
||||||
Serial: testData.stringSerial,
|
Serial: testData.stringSerial,
|
||||||
Issued: testData.NotBefore,
|
Issued: testData.NotBefore,
|
||||||
Expires: testData.NotAfter,
|
Expires: testData.NotAfter,
|
||||||
|
|
|
@ -242,6 +242,14 @@ func UniqueLowerNames(names []string) (unique []string) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// HashNames returns a hash of the names requested. This is intended for use
|
||||||
|
// when interacting with the orderFqdnSets table and rate limiting.
|
||||||
|
func HashNames(names []string) []byte {
|
||||||
|
names = UniqueLowerNames(names)
|
||||||
|
hash := sha256.Sum256([]byte(strings.Join(names, ",")))
|
||||||
|
return hash[:]
|
||||||
|
}
|
||||||
|
|
||||||
// LoadCert loads a PEM certificate specified by filename or returns an error
|
// LoadCert loads a PEM certificate specified by filename or returns an error
|
||||||
func LoadCert(filename string) (*x509.Certificate, error) {
|
func LoadCert(filename string) (*x509.Certificate, error) {
|
||||||
certPEM, err := os.ReadFile(filename)
|
certPEM, err := os.ReadFile(filename)
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package core
|
package core
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"math"
|
"math"
|
||||||
|
@ -206,3 +207,30 @@ func TestRetryBackoff(t *testing.T) {
|
||||||
assertBetween(float64(backoff), float64(expected)*0.8, float64(expected)*1.2)
|
assertBetween(float64(backoff), float64(expected)*0.8, float64(expected)*1.2)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestHashNames(t *testing.T) {
|
||||||
|
// Test that it is deterministic
|
||||||
|
h1 := HashNames([]string{"a"})
|
||||||
|
h2 := HashNames([]string{"a"})
|
||||||
|
test.AssertByteEquals(t, h1, h2)
|
||||||
|
|
||||||
|
// Test that it differentiates
|
||||||
|
h1 = HashNames([]string{"a"})
|
||||||
|
h2 = HashNames([]string{"b"})
|
||||||
|
test.Assert(t, !bytes.Equal(h1, h2), "Should have been different")
|
||||||
|
|
||||||
|
// Test that it is not subject to ordering
|
||||||
|
h1 = HashNames([]string{"a", "b"})
|
||||||
|
h2 = HashNames([]string{"b", "a"})
|
||||||
|
test.AssertByteEquals(t, h1, h2)
|
||||||
|
|
||||||
|
// Test that it is not subject to case
|
||||||
|
h1 = HashNames([]string{"a", "b"})
|
||||||
|
h2 = HashNames([]string{"A", "B"})
|
||||||
|
test.AssertByteEquals(t, h1, h2)
|
||||||
|
|
||||||
|
// Test that it is not subject to duplication
|
||||||
|
h1 = HashNames([]string{"a", "a"})
|
||||||
|
h2 = HashNames([]string{"a"})
|
||||||
|
test.AssertByteEquals(t, h1, h2)
|
||||||
|
}
|
||||||
|
|
|
@ -196,7 +196,7 @@ var (
|
||||||
errWildcardNotSupported = berrors.MalformedError("Wildcard domain names are not supported")
|
errWildcardNotSupported = berrors.MalformedError("Wildcard domain names are not supported")
|
||||||
)
|
)
|
||||||
|
|
||||||
// validDomain checks that a domain isn't:
|
// ValidDomain checks that a domain isn't:
|
||||||
//
|
//
|
||||||
// * empty
|
// * empty
|
||||||
// * prefixed with the wildcard label `*.`
|
// * prefixed with the wildcard label `*.`
|
||||||
|
@ -210,7 +210,7 @@ var (
|
||||||
// * exactly equal to an IANA registered TLD
|
// * exactly equal to an IANA registered TLD
|
||||||
//
|
//
|
||||||
// It does _not_ check that the domain isn't on any PA blocked lists.
|
// It does _not_ check that the domain isn't on any PA blocked lists.
|
||||||
func validDomain(domain string) error {
|
func ValidDomain(domain string) error {
|
||||||
if domain == "" {
|
if domain == "" {
|
||||||
return errEmptyName
|
return errEmptyName
|
||||||
}
|
}
|
||||||
|
@ -323,7 +323,7 @@ func ValidEmail(address string) error {
|
||||||
}
|
}
|
||||||
splitEmail := strings.SplitN(email.Address, "@", -1)
|
splitEmail := strings.SplitN(email.Address, "@", -1)
|
||||||
domain := strings.ToLower(splitEmail[len(splitEmail)-1])
|
domain := strings.ToLower(splitEmail[len(splitEmail)-1])
|
||||||
err = validDomain(domain)
|
err = ValidDomain(domain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return berrors.InvalidEmailError(
|
return berrors.InvalidEmailError(
|
||||||
"contact email %q has invalid domain : %s",
|
"contact email %q has invalid domain : %s",
|
||||||
|
@ -363,7 +363,7 @@ func (pa *AuthorityImpl) willingToIssue(id identifier.ACMEIdentifier) error {
|
||||||
}
|
}
|
||||||
domain := id.Value
|
domain := id.Value
|
||||||
|
|
||||||
err := validDomain(domain)
|
err := ValidDomain(domain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,190 @@
|
||||||
|
# Configuring and Storing Key-Value Rate Limits
|
||||||
|
|
||||||
|
## Rate Limit Structure
|
||||||
|
|
||||||
|
All rate limits use a token-bucket model. The metaphor is that each limit is
|
||||||
|
represented by a bucket which holds tokens. Each request removes some number of
|
||||||
|
tokens from the bucket, or is denied if there aren't enough tokens to remove.
|
||||||
|
Over time, new tokens are added to the bucket at a steady rate, until the bucket
|
||||||
|
is full. The _burst_ parameter of a rate limit indicates the maximum capacity of
|
||||||
|
a bucket: how many tokens can it hold before new ones stop being added.
|
||||||
|
Therefore, this also indicates how many requests can be made in a single burst
|
||||||
|
before a full bucket is completely emptied. The _count_ and _period_ parameters
|
||||||
|
indicate the rate at which new tokens are added to a bucket: every period, count
|
||||||
|
tokens will be added. Therefore, these also indicate the steady-state rate at
|
||||||
|
which a client which has exhausted its quota can make requests: one token every
|
||||||
|
(period / count) duration.
|
||||||
|
|
||||||
|
## Default Limit Settings
|
||||||
|
|
||||||
|
Each key directly corresponds to a `Name` enumeration as detailed in `//ratelimits/names.go`.
|
||||||
|
The `Name` enum is used to identify the particular limit. The parameters of a
|
||||||
|
default limit are the values that will be used for all buckets that do not have
|
||||||
|
an explicit override (see below).
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
NewRegistrationsPerIPAddress:
|
||||||
|
burst: 20
|
||||||
|
count: 20
|
||||||
|
period: 1s
|
||||||
|
NewOrdersPerAccount:
|
||||||
|
burst: 300
|
||||||
|
count: 300
|
||||||
|
period: 180m
|
||||||
|
```
|
||||||
|
|
||||||
|
## Override Limit Settings
|
||||||
|
|
||||||
|
Each override key represents a specific bucket, consisting of two elements:
|
||||||
|
_name_ and _id_. The name here refers to the Name of the particular limit, while
|
||||||
|
the id is a client identifier. The format of the id is dependent on the limit.
|
||||||
|
For example, the id for 'NewRegistrationsPerIPAddress' is a subscriber IP
|
||||||
|
address, while the id for 'NewOrdersPerAccount' is the subscriber's registration
|
||||||
|
ID.
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
NewRegistrationsPerIPAddress:10.0.0.2:
|
||||||
|
burst: 20
|
||||||
|
count: 40
|
||||||
|
period: 1s
|
||||||
|
NewOrdersPerAccount:12345678:
|
||||||
|
burst: 300
|
||||||
|
count: 600
|
||||||
|
period: 180m
|
||||||
|
```
|
||||||
|
|
||||||
|
The above example overrides the default limits for specific subscribers. In both
|
||||||
|
cases the count of requests per period are doubled, but the burst capacity is
|
||||||
|
explicitly configured to match the default rate limit.
|
||||||
|
|
||||||
|
### Id Formats in Limit Override Settings
|
||||||
|
|
||||||
|
Id formats vary based on the `Name` enumeration. Below are examples for each
|
||||||
|
format:
|
||||||
|
|
||||||
|
#### ipAddress
|
||||||
|
|
||||||
|
A valid IPv4 or IPv6 address.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
- `NewRegistrationsPerIPAddress:10.0.0.1`
|
||||||
|
- `NewRegistrationsPerIPAddress:2001:0db8:0000:0000:0000:ff00:0042:8329`
|
||||||
|
|
||||||
|
#### ipv6RangeCIDR
|
||||||
|
|
||||||
|
A valid IPv6 range in CIDR notation with a /48 mask. A /48 range is typically
|
||||||
|
assigned to a single subscriber.
|
||||||
|
|
||||||
|
Example: `NewRegistrationsPerIPv6Range:2001:0db8:0000::/48`
|
||||||
|
|
||||||
|
#### regId
|
||||||
|
|
||||||
|
The registration ID of the account.
|
||||||
|
|
||||||
|
Example: `NewOrdersPerAccount:12345678`
|
||||||
|
|
||||||
|
#### regId:domain
|
||||||
|
|
||||||
|
A combination of registration ID and domain, formatted 'regId:domain'.
|
||||||
|
|
||||||
|
Example: `CertificatesPerDomainPerAccount:12345678:example.com`
|
||||||
|
|
||||||
|
#### regId:fqdnSet
|
||||||
|
|
||||||
|
A combination of registration ID and a comma-separated list of domain names,
|
||||||
|
formatted 'regId:fqdnSet'.
|
||||||
|
|
||||||
|
Example: `CertificatesPerFQDNSetPerAccount:12345678:example.com,example.org`
|
||||||
|
|
||||||
|
## Bucket Key Definitions
|
||||||
|
|
||||||
|
A bucket key is used to lookup the bucket for a given limit and
|
||||||
|
subscriber. Bucket keys are formatted similarly to the overrides but with a
|
||||||
|
slight difference: the limit Names do not carry the string form of each limit.
|
||||||
|
Instead, they apply the `Name` enum equivalent for every limit.
|
||||||
|
|
||||||
|
So, instead of:
|
||||||
|
|
||||||
|
```
|
||||||
|
NewOrdersPerAccount:12345678
|
||||||
|
```
|
||||||
|
|
||||||
|
The corresponding bucket key for regId 12345678 would look like this:
|
||||||
|
|
||||||
|
```
|
||||||
|
6:12345678
|
||||||
|
```
|
||||||
|
|
||||||
|
When loaded from a file, the keys for the default/override limits undergo the
|
||||||
|
same interning process as the aforementioned subscriber bucket keys. This
|
||||||
|
eliminates the need for redundant conversions when fetching each
|
||||||
|
default/override limit.
|
||||||
|
|
||||||
|
## How Limits are Applied
|
||||||
|
|
||||||
|
Although rate limit buckets are configured in terms of tokens, we do not
|
||||||
|
actually keep track of the number of tokens in each bucket. Instead, we track
|
||||||
|
the Theoretical Arrival Time (TAT) at which the bucket will be full again. If
|
||||||
|
the TAT is in the past, the bucket is full. If the TAT is in the future, some
|
||||||
|
number of tokens have been spent and the bucket is slowly refilling. If the TAT
|
||||||
|
is far enough in the future (specifically, more than `burst * (period / count)`)
|
||||||
|
in the future), then the bucket is completely empty and requests will be denied.
|
||||||
|
|
||||||
|
Additional terminology:
|
||||||
|
|
||||||
|
- **burst offset** is the duration of time it takes for a bucket to go from
|
||||||
|
empty to full (`burst * (period / count)`).
|
||||||
|
- **emission interval** is the interval at which tokens are added to a bucket
|
||||||
|
(`period / count`). This is also the steady-state rate at which requests can
|
||||||
|
be made without being denied even once the burst has been exhausted.
|
||||||
|
- **cost** is the number of tokens removed from a bucket for a single request.
|
||||||
|
- **cost increment** is the duration of time the TAT is advanced to account
|
||||||
|
for the cost of the request (`cost * emission interval`).
|
||||||
|
|
||||||
|
For the purposes of this example, subscribers originating from a specific IPv4
|
||||||
|
address are allowed 20 requests to the newFoo endpoint per second, with a
|
||||||
|
maximum burst of 20 requests at any point-in-time, or:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
NewFoosPerIPAddress:172.23.45.22:
|
||||||
|
burst: 20
|
||||||
|
count: 20
|
||||||
|
period: 1s
|
||||||
|
```
|
||||||
|
|
||||||
|
A subscriber calls the newFoo endpoint for the first time with an IP address of
|
||||||
|
172.23.45.22. Here's what happens:
|
||||||
|
|
||||||
|
1. The subscriber's IP address is used to generate a bucket key in the form of
|
||||||
|
'NewFoosPerIPAddress:172.23.45.22'.
|
||||||
|
|
||||||
|
2. The request is approved and the 'NewFoosPerIPAddress:172.23.45.22' bucket is
|
||||||
|
initialized with 19 tokens, as 1 token has been removed to account for the
|
||||||
|
cost of the current request. To accomplish this, the initial TAT is set to
|
||||||
|
the current time plus the _cost increment_ (which is 1/20th of a second if we
|
||||||
|
are limiting to 20 requests per second).
|
||||||
|
|
||||||
|
3. Bucket 'NewFoosPerIPAddress:172.23.45.22':
|
||||||
|
- will reset to full in 50ms (1/20th of a second),
|
||||||
|
- will allow another newFoo request immediately,
|
||||||
|
- will allow between 1 and 19 more requests in the next 50ms,
|
||||||
|
- will reject the 20th request made in the next 50ms,
|
||||||
|
- and will allow 1 request every 50ms, indefinitely.
|
||||||
|
|
||||||
|
The subscriber makes another request 5ms later:
|
||||||
|
|
||||||
|
4. The TAT at bucket key 'NewFoosPerIPAddress:172.23.45.22' is compared against
|
||||||
|
the current time and the _burst offset_. The current time is greater than the
|
||||||
|
TAT minus the cost increment. Therefore, the request is approved.
|
||||||
|
|
||||||
|
5. The TAT at bucket key 'NewFoosPerIPAddress:172.23.45.22' is advanced by the
|
||||||
|
cost increment to account for the cost of the request.
|
||||||
|
|
||||||
|
The subscriber makes a total of 18 requests over the next 44ms:
|
||||||
|
|
||||||
|
6. The current time is less than the TAT at bucket key
|
||||||
|
'NewFoosPerIPAddress:172.23.45.22' minus the burst offset, thus the request
|
||||||
|
is rejected.
|
||||||
|
|
||||||
|
This mechanism allows for bursts of traffic but also ensures that the average
|
||||||
|
rate of requests stays within the prescribed limits over time.
|
|
@ -0,0 +1,110 @@
|
||||||
|
package ratelimits
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/jmhodges/clock"
|
||||||
|
)
|
||||||
|
|
||||||
|
// maybeSpend uses the GCRA algorithm to decide whether to allow a request. It
|
||||||
|
// returns a Decision struct with the result of the decision and the updated
|
||||||
|
// TAT. The cost must be 0 or greater and <= the burst capacity of the limit.
|
||||||
|
func maybeSpend(clk clock.Clock, rl limit, tat time.Time, cost int64) *Decision {
|
||||||
|
if cost < 0 || cost > rl.Burst {
|
||||||
|
// The condition above is the union of the conditions checked in Check
|
||||||
|
// and Spend methods of Limiter. If this panic is reached, it means that
|
||||||
|
// the caller has introduced a bug.
|
||||||
|
panic("invalid cost for maybeSpend")
|
||||||
|
}
|
||||||
|
nowUnix := clk.Now().UnixNano()
|
||||||
|
tatUnix := tat.UnixNano()
|
||||||
|
|
||||||
|
// If the TAT is in the future, use it as the starting point for the
|
||||||
|
// calculation. Otherwise, use the current time. This is to prevent the
|
||||||
|
// bucket from being filled with capacity from the past.
|
||||||
|
if nowUnix > tatUnix {
|
||||||
|
tatUnix = nowUnix
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compute the cost increment.
|
||||||
|
costIncrement := rl.emissionInterval * cost
|
||||||
|
|
||||||
|
// Deduct the cost to find the new TAT and residual capacity.
|
||||||
|
newTAT := tatUnix + costIncrement
|
||||||
|
difference := nowUnix - (newTAT - rl.burstOffset)
|
||||||
|
|
||||||
|
if difference < 0 {
|
||||||
|
// Too little capacity to satisfy the cost, deny the request.
|
||||||
|
residual := (nowUnix - (tatUnix - rl.burstOffset)) / rl.emissionInterval
|
||||||
|
return &Decision{
|
||||||
|
Allowed: false,
|
||||||
|
Remaining: residual,
|
||||||
|
RetryIn: -time.Duration(difference),
|
||||||
|
ResetIn: time.Duration(tatUnix - nowUnix),
|
||||||
|
newTAT: time.Unix(0, tatUnix).UTC(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// There is enough capacity to satisfy the cost, allow the request.
|
||||||
|
var retryIn time.Duration
|
||||||
|
residual := difference / rl.emissionInterval
|
||||||
|
if difference < costIncrement {
|
||||||
|
retryIn = time.Duration(costIncrement - difference)
|
||||||
|
}
|
||||||
|
return &Decision{
|
||||||
|
Allowed: true,
|
||||||
|
Remaining: residual,
|
||||||
|
RetryIn: retryIn,
|
||||||
|
ResetIn: time.Duration(newTAT - nowUnix),
|
||||||
|
newTAT: time.Unix(0, newTAT).UTC(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// maybeRefund uses the Generic Cell Rate Algorithm (GCRA) to attempt to refund
|
||||||
|
// the cost of a request which was previously spent. The refund cost must be 0
|
||||||
|
// or greater. A cost will only be refunded up to the burst capacity of the
|
||||||
|
// limit. A partial refund is still considered successful.
|
||||||
|
func maybeRefund(clk clock.Clock, rl limit, tat time.Time, cost int64) *Decision {
|
||||||
|
if cost <= 0 || cost > rl.Burst {
|
||||||
|
// The condition above is checked in the Refund method of Limiter. If
|
||||||
|
// this panic is reached, it means that the caller has introduced a bug.
|
||||||
|
panic("invalid cost for maybeRefund")
|
||||||
|
}
|
||||||
|
nowUnix := clk.Now().UnixNano()
|
||||||
|
tatUnix := tat.UnixNano()
|
||||||
|
|
||||||
|
// The TAT must be in the future to refund capacity.
|
||||||
|
if nowUnix > tatUnix {
|
||||||
|
// The TAT is in the past, therefore the bucket is full.
|
||||||
|
return &Decision{
|
||||||
|
Allowed: false,
|
||||||
|
Remaining: rl.Burst,
|
||||||
|
RetryIn: time.Duration(0),
|
||||||
|
ResetIn: time.Duration(0),
|
||||||
|
newTAT: tat,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compute the refund increment.
|
||||||
|
refundIncrement := rl.emissionInterval * cost
|
||||||
|
|
||||||
|
// Subtract the refund increment from the TAT to find the new TAT.
|
||||||
|
newTAT := tatUnix - refundIncrement
|
||||||
|
|
||||||
|
// Ensure the new TAT is not earlier than now.
|
||||||
|
if newTAT < nowUnix {
|
||||||
|
newTAT = nowUnix
|
||||||
|
}
|
||||||
|
|
||||||
|
// Calculate the new capacity.
|
||||||
|
difference := nowUnix - (newTAT - rl.burstOffset)
|
||||||
|
residual := difference / rl.emissionInterval
|
||||||
|
|
||||||
|
return &Decision{
|
||||||
|
Allowed: (newTAT != tatUnix),
|
||||||
|
Remaining: residual,
|
||||||
|
RetryIn: time.Duration(0),
|
||||||
|
ResetIn: time.Duration(newTAT - nowUnix),
|
||||||
|
newTAT: time.Unix(0, newTAT).UTC(),
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,221 @@
|
||||||
|
package ratelimits
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/jmhodges/clock"
|
||||||
|
"github.com/letsencrypt/boulder/config"
|
||||||
|
"github.com/letsencrypt/boulder/test"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Test_decide(t *testing.T) {
|
||||||
|
clk := clock.NewFake()
|
||||||
|
limit := precomputeLimit(
|
||||||
|
limit{Burst: 10, Count: 1, Period: config.Duration{Duration: time.Second}},
|
||||||
|
)
|
||||||
|
|
||||||
|
// Begin by using 1 of our 10 requests.
|
||||||
|
d := maybeSpend(clk, limit, clk.Now(), 1)
|
||||||
|
test.Assert(t, d.Allowed, "should be allowed")
|
||||||
|
test.AssertEquals(t, d.Remaining, int64(9))
|
||||||
|
test.AssertEquals(t, d.RetryIn, time.Duration(0))
|
||||||
|
test.AssertEquals(t, d.ResetIn, time.Second)
|
||||||
|
|
||||||
|
// Immediately use another 9 of our remaining requests.
|
||||||
|
d = maybeSpend(clk, limit, d.newTAT, 9)
|
||||||
|
test.Assert(t, d.Allowed, "should be allowed")
|
||||||
|
test.AssertEquals(t, d.Remaining, int64(0))
|
||||||
|
// We should have to wait 1 second before we can use another request but we
|
||||||
|
// used 9 so we should have to wait 9 seconds to make an identical request.
|
||||||
|
test.AssertEquals(t, d.RetryIn, time.Second*9)
|
||||||
|
test.AssertEquals(t, d.ResetIn, time.Second*10)
|
||||||
|
|
||||||
|
// Our new TAT should be 10 seconds (limit.Burst) in the future.
|
||||||
|
test.AssertEquals(t, d.newTAT, clk.Now().Add(time.Second*10))
|
||||||
|
|
||||||
|
// Let's try using just 1 more request without waiting.
|
||||||
|
d = maybeSpend(clk, limit, d.newTAT, 1)
|
||||||
|
test.Assert(t, !d.Allowed, "should not be allowed")
|
||||||
|
test.AssertEquals(t, d.Remaining, int64(0))
|
||||||
|
test.AssertEquals(t, d.RetryIn, time.Second)
|
||||||
|
test.AssertEquals(t, d.ResetIn, time.Second*10)
|
||||||
|
|
||||||
|
// Let's try being exactly as patient as we're told to be.
|
||||||
|
clk.Add(d.RetryIn)
|
||||||
|
d = maybeSpend(clk, limit, d.newTAT, 0)
|
||||||
|
test.AssertEquals(t, d.Remaining, int64(1))
|
||||||
|
|
||||||
|
// We are 1 second in the future, we should have 1 new request.
|
||||||
|
d = maybeSpend(clk, limit, d.newTAT, 1)
|
||||||
|
test.Assert(t, d.Allowed, "should be allowed")
|
||||||
|
test.AssertEquals(t, d.Remaining, int64(0))
|
||||||
|
test.AssertEquals(t, d.RetryIn, time.Second)
|
||||||
|
test.AssertEquals(t, d.ResetIn, time.Second*10)
|
||||||
|
|
||||||
|
// Let's try waiting (10 seconds) for our whole bucket to refill.
|
||||||
|
clk.Add(d.ResetIn)
|
||||||
|
|
||||||
|
// We should have 10 new requests. If we use 1 we should have 9 remaining.
|
||||||
|
d = maybeSpend(clk, limit, d.newTAT, 1)
|
||||||
|
test.Assert(t, d.Allowed, "should be allowed")
|
||||||
|
test.AssertEquals(t, d.Remaining, int64(9))
|
||||||
|
test.AssertEquals(t, d.RetryIn, time.Duration(0))
|
||||||
|
test.AssertEquals(t, d.ResetIn, time.Second)
|
||||||
|
|
||||||
|
// Wait just shy of how long we're told to wait for refilling.
|
||||||
|
clk.Add(d.ResetIn - time.Millisecond)
|
||||||
|
|
||||||
|
// We should still have 9 remaining because we're still 1ms shy of the
|
||||||
|
// refill time.
|
||||||
|
d = maybeSpend(clk, limit, d.newTAT, 0)
|
||||||
|
test.Assert(t, d.Allowed, "should be allowed")
|
||||||
|
test.AssertEquals(t, d.Remaining, int64(9))
|
||||||
|
test.AssertEquals(t, d.RetryIn, time.Duration(0))
|
||||||
|
test.AssertEquals(t, d.ResetIn, time.Millisecond)
|
||||||
|
|
||||||
|
// Spending 0 simply informed us that we still have 9 remaining, let's see
|
||||||
|
// what we have after waiting 20 hours.
|
||||||
|
clk.Add(20 * time.Hour)
|
||||||
|
|
||||||
|
// C'mon, big money, no whammies, no whammies, STOP!
|
||||||
|
d = maybeSpend(clk, limit, d.newTAT, 0)
|
||||||
|
test.Assert(t, d.Allowed, "should be allowed")
|
||||||
|
test.AssertEquals(t, d.Remaining, int64(10))
|
||||||
|
test.AssertEquals(t, d.RetryIn, time.Duration(0))
|
||||||
|
test.AssertEquals(t, d.ResetIn, time.Duration(0))
|
||||||
|
|
||||||
|
// Turns out that the most we can accrue is 10 (limit.Burst). Let's empty
|
||||||
|
// this bucket out so we can try something else.
|
||||||
|
d = maybeSpend(clk, limit, d.newTAT, 10)
|
||||||
|
test.Assert(t, d.Allowed, "should be allowed")
|
||||||
|
test.AssertEquals(t, d.Remaining, int64(0))
|
||||||
|
// We should have to wait 1 second before we can use another request but we
|
||||||
|
// used 10 so we should have to wait 10 seconds to make an identical
|
||||||
|
// request.
|
||||||
|
test.AssertEquals(t, d.RetryIn, time.Second*10)
|
||||||
|
test.AssertEquals(t, d.ResetIn, time.Second*10)
|
||||||
|
|
||||||
|
// If you spend 0 while you have 0 you should get 0.
|
||||||
|
d = maybeSpend(clk, limit, d.newTAT, 0)
|
||||||
|
test.Assert(t, d.Allowed, "should be allowed")
|
||||||
|
test.AssertEquals(t, d.Remaining, int64(0))
|
||||||
|
test.AssertEquals(t, d.RetryIn, time.Duration(0))
|
||||||
|
test.AssertEquals(t, d.ResetIn, time.Second*10)
|
||||||
|
|
||||||
|
// We don't play by the rules, we spend 1 when we have 0.
|
||||||
|
d = maybeSpend(clk, limit, d.newTAT, 1)
|
||||||
|
test.Assert(t, !d.Allowed, "should not be allowed")
|
||||||
|
test.AssertEquals(t, d.Remaining, int64(0))
|
||||||
|
test.AssertEquals(t, d.RetryIn, time.Second)
|
||||||
|
test.AssertEquals(t, d.ResetIn, time.Second*10)
|
||||||
|
|
||||||
|
// Okay, maybe we should play by the rules if we want to get anywhere.
|
||||||
|
clk.Add(d.RetryIn)
|
||||||
|
|
||||||
|
// Our patience pays off, we should have 1 new request. Let's use it.
|
||||||
|
d = maybeSpend(clk, limit, d.newTAT, 1)
|
||||||
|
test.Assert(t, d.Allowed, "should be allowed")
|
||||||
|
test.AssertEquals(t, d.Remaining, int64(0))
|
||||||
|
test.AssertEquals(t, d.RetryIn, time.Second)
|
||||||
|
test.AssertEquals(t, d.ResetIn, time.Second*10)
|
||||||
|
|
||||||
|
// Refill from empty to 5.
|
||||||
|
clk.Add(d.ResetIn / 2)
|
||||||
|
|
||||||
|
// Attempt to spend 7 when we only have 5. We should be denied but the
|
||||||
|
// decision should reflect a retry of 2 seconds, the time it would take to
|
||||||
|
// refill from 5 to 7.
|
||||||
|
d = maybeSpend(clk, limit, d.newTAT, 7)
|
||||||
|
test.Assert(t, !d.Allowed, "should not be allowed")
|
||||||
|
test.AssertEquals(t, d.Remaining, int64(5))
|
||||||
|
test.AssertEquals(t, d.RetryIn, time.Second*2)
|
||||||
|
test.AssertEquals(t, d.ResetIn, time.Second*5)
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_maybeRefund(t *testing.T) {
|
||||||
|
clk := clock.NewFake()
|
||||||
|
limit := precomputeLimit(
|
||||||
|
limit{Burst: 10, Count: 1, Period: config.Duration{Duration: time.Second}},
|
||||||
|
)
|
||||||
|
|
||||||
|
// Begin by using 1 of our 10 requests.
|
||||||
|
d := maybeSpend(clk, limit, clk.Now(), 1)
|
||||||
|
test.Assert(t, d.Allowed, "should be allowed")
|
||||||
|
test.AssertEquals(t, d.Remaining, int64(9))
|
||||||
|
test.AssertEquals(t, d.RetryIn, time.Duration(0))
|
||||||
|
test.AssertEquals(t, d.ResetIn, time.Second)
|
||||||
|
|
||||||
|
// Refund back to 10.
|
||||||
|
d = maybeRefund(clk, limit, d.newTAT, 1)
|
||||||
|
test.AssertEquals(t, d.Remaining, int64(10))
|
||||||
|
test.AssertEquals(t, d.RetryIn, time.Duration(0))
|
||||||
|
test.AssertEquals(t, d.ResetIn, time.Duration(0))
|
||||||
|
|
||||||
|
// Spend 1 more of our 10 requests.
|
||||||
|
d = maybeSpend(clk, limit, d.newTAT, 1)
|
||||||
|
test.Assert(t, d.Allowed, "should be allowed")
|
||||||
|
test.AssertEquals(t, d.Remaining, int64(9))
|
||||||
|
test.AssertEquals(t, d.RetryIn, time.Duration(0))
|
||||||
|
test.AssertEquals(t, d.ResetIn, time.Second)
|
||||||
|
|
||||||
|
// Wait for our bucket to refill.
|
||||||
|
clk.Add(d.ResetIn)
|
||||||
|
|
||||||
|
// Attempt to refund from 10 to 11.
|
||||||
|
d = maybeRefund(clk, limit, d.newTAT, 1)
|
||||||
|
test.Assert(t, !d.Allowed, "should not be allowed")
|
||||||
|
test.AssertEquals(t, d.Remaining, int64(10))
|
||||||
|
test.AssertEquals(t, d.RetryIn, time.Duration(0))
|
||||||
|
test.AssertEquals(t, d.ResetIn, time.Duration(0))
|
||||||
|
|
||||||
|
// Spend 10 all 10 of our requests.
|
||||||
|
d = maybeSpend(clk, limit, d.newTAT, 10)
|
||||||
|
test.Assert(t, d.Allowed, "should be allowed")
|
||||||
|
test.AssertEquals(t, d.Remaining, int64(0))
|
||||||
|
// We should have to wait 1 second before we can use another request but we
|
||||||
|
// used 10 so we should have to wait 10 seconds to make an identical
|
||||||
|
// request.
|
||||||
|
test.AssertEquals(t, d.RetryIn, time.Second*10)
|
||||||
|
test.AssertEquals(t, d.ResetIn, time.Second*10)
|
||||||
|
|
||||||
|
// Attempt a refund of 10.
|
||||||
|
d = maybeRefund(clk, limit, d.newTAT, 10)
|
||||||
|
test.AssertEquals(t, d.Remaining, int64(10))
|
||||||
|
test.AssertEquals(t, d.RetryIn, time.Duration(0))
|
||||||
|
test.AssertEquals(t, d.ResetIn, time.Duration(0))
|
||||||
|
|
||||||
|
// Wait 11 seconds to catching up to TAT.
|
||||||
|
clk.Add(11 * time.Second)
|
||||||
|
|
||||||
|
// Attempt to refund to 11, then ensure it's still 10.
|
||||||
|
d = maybeRefund(clk, limit, d.newTAT, 1)
|
||||||
|
test.Assert(t, !d.Allowed, "should be allowed")
|
||||||
|
test.AssertEquals(t, d.Remaining, int64(10))
|
||||||
|
test.AssertEquals(t, d.RetryIn, time.Duration(0))
|
||||||
|
test.AssertEquals(t, d.ResetIn, time.Duration(0))
|
||||||
|
|
||||||
|
// Spend 5 of our 10 requests, then refund 1.
|
||||||
|
d = maybeSpend(clk, limit, d.newTAT, 5)
|
||||||
|
d = maybeRefund(clk, limit, d.newTAT, 1)
|
||||||
|
test.Assert(t, d.Allowed, "should be allowed")
|
||||||
|
test.AssertEquals(t, d.Remaining, int64(6))
|
||||||
|
test.AssertEquals(t, d.RetryIn, time.Duration(0))
|
||||||
|
|
||||||
|
// Wait, a 2.5 seconds to refill to 8.5 requests.
|
||||||
|
clk.Add(time.Millisecond * 2500)
|
||||||
|
|
||||||
|
// Ensure we have 8.5 requests.
|
||||||
|
d = maybeSpend(clk, limit, d.newTAT, 0)
|
||||||
|
test.Assert(t, d.Allowed, "should be allowed")
|
||||||
|
test.AssertEquals(t, d.Remaining, int64(8))
|
||||||
|
test.AssertEquals(t, d.RetryIn, time.Duration(0))
|
||||||
|
// Check that ResetIn represents the fractional earned request.
|
||||||
|
test.AssertEquals(t, d.ResetIn, time.Millisecond*1500)
|
||||||
|
|
||||||
|
// Refund 2 requests, we should only have 10, not 10.5.
|
||||||
|
d = maybeRefund(clk, limit, d.newTAT, 2)
|
||||||
|
test.AssertEquals(t, d.Remaining, int64(10))
|
||||||
|
test.AssertEquals(t, d.RetryIn, time.Duration(0))
|
||||||
|
test.AssertEquals(t, d.ResetIn, time.Duration(0))
|
||||||
|
}
|
|
@ -0,0 +1,160 @@
|
||||||
|
package ratelimits
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/letsencrypt/boulder/config"
|
||||||
|
"github.com/letsencrypt/boulder/core"
|
||||||
|
"github.com/letsencrypt/boulder/strictyaml"
|
||||||
|
)
|
||||||
|
|
||||||
|
type limit struct {
|
||||||
|
// Burst specifies maximum concurrent allowed requests at any given time. It
|
||||||
|
// must be greater than zero.
|
||||||
|
Burst int64
|
||||||
|
|
||||||
|
// Count is the number of requests allowed per period. It must be greater
|
||||||
|
// than zero.
|
||||||
|
Count int64
|
||||||
|
|
||||||
|
// Period is the duration of time in which the count (of requests) is
|
||||||
|
// allowed. It must be greater than zero.
|
||||||
|
Period config.Duration
|
||||||
|
|
||||||
|
// emissionInterval is the interval, in nanoseconds, at which tokens are
|
||||||
|
// added to a bucket (period / count). This is also the steady-state rate at
|
||||||
|
// which requests can be made without being denied even once the burst has
|
||||||
|
// been exhausted. This is precomputed to avoid doing the same calculation
|
||||||
|
// on every request.
|
||||||
|
emissionInterval int64
|
||||||
|
|
||||||
|
// burstOffset is the duration of time, in nanoseconds, it takes for a
|
||||||
|
// bucket to go from empty to full (burst * (period / count)). This is
|
||||||
|
// precomputed to avoid doing the same calculation on every request.
|
||||||
|
burstOffset int64
|
||||||
|
}
|
||||||
|
|
||||||
|
func precomputeLimit(l limit) limit {
|
||||||
|
l.emissionInterval = l.Period.Nanoseconds() / l.Count
|
||||||
|
l.burstOffset = l.emissionInterval * l.Burst
|
||||||
|
return l
|
||||||
|
}
|
||||||
|
|
||||||
|
func validateLimit(l limit) error {
|
||||||
|
if l.Burst <= 0 {
|
||||||
|
return fmt.Errorf("invalid burst '%d', must be > 0", l.Burst)
|
||||||
|
}
|
||||||
|
if l.Count <= 0 {
|
||||||
|
return fmt.Errorf("invalid count '%d', must be > 0", l.Count)
|
||||||
|
}
|
||||||
|
if l.Period.Duration <= 0 {
|
||||||
|
return fmt.Errorf("invalid period '%s', must be > 0", l.Period)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type limits map[string]limit
|
||||||
|
|
||||||
|
// loadLimits marshals the YAML file at path into a map of limis.
|
||||||
|
func loadLimits(path string) (limits, error) {
|
||||||
|
lm := make(limits)
|
||||||
|
data, err := os.ReadFile(path)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
err = strictyaml.Unmarshal(data, &lm)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return lm, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseOverrideNameId is broken out for ease of testing.
|
||||||
|
func parseOverrideNameId(key string) (Name, string, error) {
|
||||||
|
if !strings.Contains(key, ":") {
|
||||||
|
// Avoids a potential panic in strings.SplitN below.
|
||||||
|
return Unknown, "", fmt.Errorf("invalid override %q, must be formatted 'name:id'", key)
|
||||||
|
}
|
||||||
|
nameAndId := strings.SplitN(key, ":", 2)
|
||||||
|
nameStr := nameAndId[0]
|
||||||
|
if nameStr == "" {
|
||||||
|
return Unknown, "", fmt.Errorf("empty name in override %q, must be formatted 'name:id'", key)
|
||||||
|
}
|
||||||
|
|
||||||
|
name, ok := stringToName[nameStr]
|
||||||
|
if !ok {
|
||||||
|
return Unknown, "", fmt.Errorf("unrecognized name %q in override limit %q, must be one of %v", nameStr, key, limitNames)
|
||||||
|
}
|
||||||
|
id := nameAndId[1]
|
||||||
|
if id == "" {
|
||||||
|
return Unknown, "", fmt.Errorf("empty id in override %q, must be formatted 'name:id'", key)
|
||||||
|
}
|
||||||
|
return name, id, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// loadAndParseOverrideLimits loads override limits from YAML, validates them,
|
||||||
|
// and parses them into a map of limits keyed by 'Name:id'.
|
||||||
|
func loadAndParseOverrideLimits(path string) (limits, error) {
|
||||||
|
fromFile, err := loadLimits(path)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
parsed := make(limits, len(fromFile))
|
||||||
|
|
||||||
|
for k, v := range fromFile {
|
||||||
|
err = validateLimit(v)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("validating override limit %q: %w", k, err)
|
||||||
|
}
|
||||||
|
name, id, err := parseOverrideNameId(k)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("parsing override limit %q: %w", k, err)
|
||||||
|
}
|
||||||
|
err = validateIdForName(name, id)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf(
|
||||||
|
"validating name %s and id %q for override limit %q: %w", nameToString[name], id, k, err)
|
||||||
|
}
|
||||||
|
if name == CertificatesPerFQDNSetPerAccount {
|
||||||
|
// FQDNSet hashes are not a nice thing to ask for in a config file,
|
||||||
|
// so we allow the user to specify a comma-separated list of FQDNs
|
||||||
|
// and compute the hash here.
|
||||||
|
regIdDomains := strings.SplitN(id, ":", 2)
|
||||||
|
if len(regIdDomains) != 2 {
|
||||||
|
// Should never happen, the Id format was validated above.
|
||||||
|
return nil, fmt.Errorf("invalid override limit %q, must be formatted 'name:id'", k)
|
||||||
|
}
|
||||||
|
regId := regIdDomains[0]
|
||||||
|
domains := strings.Split(regIdDomains[1], ",")
|
||||||
|
fqdnSet := core.HashNames(domains)
|
||||||
|
id = fmt.Sprintf("%s:%s", regId, fqdnSet)
|
||||||
|
}
|
||||||
|
parsed[bucketKey(name, id)] = precomputeLimit(v)
|
||||||
|
}
|
||||||
|
return parsed, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// loadAndParseDefaultLimits loads default limits from YAML, validates them, and
|
||||||
|
// parses them into a map of limits keyed by 'Name'.
|
||||||
|
func loadAndParseDefaultLimits(path string) (limits, error) {
|
||||||
|
fromFile, err := loadLimits(path)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
parsed := make(limits, len(fromFile))
|
||||||
|
|
||||||
|
for k, v := range fromFile {
|
||||||
|
err := validateLimit(v)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("parsing default limit %q: %w", k, err)
|
||||||
|
}
|
||||||
|
name, ok := stringToName[k]
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("unrecognized name %q in default limit, must be one of %v", k, limitNames)
|
||||||
|
}
|
||||||
|
parsed[nameToEnumString(name)] = precomputeLimit(v)
|
||||||
|
}
|
||||||
|
return parsed, nil
|
||||||
|
}
|
|
@ -0,0 +1,333 @@
|
||||||
|
package ratelimits
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/letsencrypt/boulder/config"
|
||||||
|
"github.com/letsencrypt/boulder/core"
|
||||||
|
"github.com/letsencrypt/boulder/test"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Test_parseOverrideNameId(t *testing.T) {
|
||||||
|
newRegistrationsPerIPAddressStr := nameToString[NewRegistrationsPerIPAddress]
|
||||||
|
newRegistrationsPerIPv6RangeStr := nameToString[NewRegistrationsPerIPv6Range]
|
||||||
|
|
||||||
|
// 'enum:ipv4'
|
||||||
|
// Valid IPv4 address.
|
||||||
|
name, id, err := parseOverrideNameId(newRegistrationsPerIPAddressStr + ":10.0.0.1")
|
||||||
|
test.AssertNotError(t, err, "should not error")
|
||||||
|
test.AssertEquals(t, name, NewRegistrationsPerIPAddress)
|
||||||
|
test.AssertEquals(t, id, "10.0.0.1")
|
||||||
|
|
||||||
|
// 'enum:ipv6range'
|
||||||
|
// Valid IPv6 address range.
|
||||||
|
name, id, err = parseOverrideNameId(newRegistrationsPerIPv6RangeStr + ":2001:0db8:0000::/48")
|
||||||
|
test.AssertNotError(t, err, "should not error")
|
||||||
|
test.AssertEquals(t, name, NewRegistrationsPerIPv6Range)
|
||||||
|
test.AssertEquals(t, id, "2001:0db8:0000::/48")
|
||||||
|
|
||||||
|
// Missing colon (this should never happen but we should avoid panicking).
|
||||||
|
_, _, err = parseOverrideNameId(newRegistrationsPerIPAddressStr + "10.0.0.1")
|
||||||
|
test.AssertError(t, err, "missing colon")
|
||||||
|
|
||||||
|
// Empty string.
|
||||||
|
_, _, err = parseOverrideNameId("")
|
||||||
|
test.AssertError(t, err, "empty string")
|
||||||
|
|
||||||
|
// Only a colon.
|
||||||
|
_, _, err = parseOverrideNameId(newRegistrationsPerIPAddressStr + ":")
|
||||||
|
test.AssertError(t, err, "only a colon")
|
||||||
|
|
||||||
|
// Invalid enum.
|
||||||
|
_, _, err = parseOverrideNameId("lol:noexist")
|
||||||
|
test.AssertError(t, err, "invalid enum")
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_validateLimit(t *testing.T) {
|
||||||
|
err := validateLimit(limit{Burst: 1, Count: 1, Period: config.Duration{Duration: time.Second}})
|
||||||
|
test.AssertNotError(t, err, "valid limit")
|
||||||
|
|
||||||
|
// All of the following are invalid.
|
||||||
|
for _, l := range []limit{
|
||||||
|
{Burst: 0, Count: 1, Period: config.Duration{Duration: time.Second}},
|
||||||
|
{Burst: 1, Count: 0, Period: config.Duration{Duration: time.Second}},
|
||||||
|
{Burst: 1, Count: 1, Period: config.Duration{Duration: 0}},
|
||||||
|
} {
|
||||||
|
err = validateLimit(l)
|
||||||
|
test.AssertError(t, err, "limit should be invalid")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_validateIdForName(t *testing.T) {
|
||||||
|
// 'enum:ipAddress'
|
||||||
|
// Valid IPv4 address.
|
||||||
|
err := validateIdForName(NewRegistrationsPerIPAddress, "10.0.0.1")
|
||||||
|
test.AssertNotError(t, err, "valid ipv4 address")
|
||||||
|
|
||||||
|
// 'enum:ipAddress'
|
||||||
|
// Valid IPv6 address.
|
||||||
|
err = validateIdForName(NewRegistrationsPerIPAddress, "2001:0db8:85a3:0000:0000:8a2e:0370:7334")
|
||||||
|
test.AssertNotError(t, err, "valid ipv6 address")
|
||||||
|
|
||||||
|
// 'enum:ipv6rangeCIDR'
|
||||||
|
// Valid IPv6 address range.
|
||||||
|
err = validateIdForName(NewRegistrationsPerIPv6Range, "2001:0db8:0000::/48")
|
||||||
|
test.AssertNotError(t, err, "should not error")
|
||||||
|
|
||||||
|
// 'enum:regId'
|
||||||
|
// Valid regId.
|
||||||
|
err = validateIdForName(NewOrdersPerAccount, "1234567890")
|
||||||
|
test.AssertNotError(t, err, "valid regId")
|
||||||
|
|
||||||
|
// 'enum:regId:domain'
|
||||||
|
// Valid regId and domain.
|
||||||
|
err = validateIdForName(CertificatesPerDomainPerAccount, "1234567890:example.com")
|
||||||
|
test.AssertNotError(t, err, "valid regId and domain")
|
||||||
|
|
||||||
|
// 'enum:regId:fqdnSet'
|
||||||
|
// Valid regId and FQDN set containing a single domain.
|
||||||
|
err = validateIdForName(CertificatesPerFQDNSetPerAccount, "1234567890:example.com")
|
||||||
|
test.AssertNotError(t, err, "valid regId and FQDN set containing a single domain")
|
||||||
|
|
||||||
|
// 'enum:regId:fqdnSet'
|
||||||
|
// Valid regId and FQDN set containing multiple domains.
|
||||||
|
err = validateIdForName(CertificatesPerFQDNSetPerAccount, "1234567890:example.com,example.org")
|
||||||
|
test.AssertNotError(t, err, "valid regId and FQDN set containing multiple domains")
|
||||||
|
|
||||||
|
// Empty string.
|
||||||
|
err = validateIdForName(NewRegistrationsPerIPAddress, "")
|
||||||
|
test.AssertError(t, err, "Id is an empty string")
|
||||||
|
|
||||||
|
// One space.
|
||||||
|
err = validateIdForName(NewRegistrationsPerIPAddress, " ")
|
||||||
|
test.AssertError(t, err, "Id is a single space")
|
||||||
|
|
||||||
|
// Invalid IPv4 address.
|
||||||
|
err = validateIdForName(NewRegistrationsPerIPAddress, "10.0.0.9000")
|
||||||
|
test.AssertError(t, err, "invalid IPv4 address")
|
||||||
|
|
||||||
|
// Invalid IPv6 address.
|
||||||
|
err = validateIdForName(NewRegistrationsPerIPAddress, "2001:0db8:85a3:0000:0000:8a2e:0370:7334:9000")
|
||||||
|
test.AssertError(t, err, "invalid IPv6 address")
|
||||||
|
|
||||||
|
// Invalid IPv6 CIDR range.
|
||||||
|
err = validateIdForName(NewRegistrationsPerIPv6Range, "2001:0db8:0000::/128")
|
||||||
|
test.AssertError(t, err, "invalid IPv6 CIDR range")
|
||||||
|
|
||||||
|
// Invalid IPv6 CIDR.
|
||||||
|
err = validateIdForName(NewRegistrationsPerIPv6Range, "2001:0db8:0000::/48/48")
|
||||||
|
test.AssertError(t, err, "invalid IPv6 CIDR")
|
||||||
|
|
||||||
|
// IPv4 CIDR when we expect IPv6 CIDR range.
|
||||||
|
err = validateIdForName(NewRegistrationsPerIPv6Range, "10.0.0.0/16")
|
||||||
|
test.AssertError(t, err, "ipv4 cidr when we expect ipv6 cidr range")
|
||||||
|
|
||||||
|
// Invalid regId.
|
||||||
|
err = validateIdForName(NewOrdersPerAccount, "lol")
|
||||||
|
test.AssertError(t, err, "invalid regId")
|
||||||
|
|
||||||
|
// Invalid regId with good domain.
|
||||||
|
err = validateIdForName(CertificatesPerDomainPerAccount, "lol:example.com")
|
||||||
|
test.AssertError(t, err, "invalid regId with good domain")
|
||||||
|
|
||||||
|
// Valid regId with bad domain.
|
||||||
|
err = validateIdForName(CertificatesPerDomainPerAccount, "1234567890:lol")
|
||||||
|
test.AssertError(t, err, "valid regId with bad domain")
|
||||||
|
|
||||||
|
// Empty regId with good domain.
|
||||||
|
err = validateIdForName(CertificatesPerDomainPerAccount, ":lol")
|
||||||
|
test.AssertError(t, err, "valid regId with bad domain")
|
||||||
|
|
||||||
|
// Valid regId with empty domain.
|
||||||
|
err = validateIdForName(CertificatesPerDomainPerAccount, "1234567890:")
|
||||||
|
test.AssertError(t, err, "valid regId with empty domain")
|
||||||
|
|
||||||
|
// Empty regId with empty domain, no separator.
|
||||||
|
err = validateIdForName(CertificatesPerDomainPerAccount, "")
|
||||||
|
test.AssertError(t, err, "empty regId with empty domain, no separator")
|
||||||
|
|
||||||
|
// Instead of anything we would expect, we get lol.
|
||||||
|
err = validateIdForName(CertificatesPerDomainPerAccount, "lol")
|
||||||
|
test.AssertError(t, err, "instead of anything we would expect, just lol")
|
||||||
|
|
||||||
|
// Valid regId with good domain and a secret third separator.
|
||||||
|
err = validateIdForName(CertificatesPerDomainPerAccount, "1234567890:example.com:lol")
|
||||||
|
test.AssertError(t, err, "valid regId with good domain and a secret third separator")
|
||||||
|
|
||||||
|
// Valid regId with bad FQDN set.
|
||||||
|
err = validateIdForName(CertificatesPerFQDNSetPerAccount, "1234567890:lol..99")
|
||||||
|
test.AssertError(t, err, "valid regId with bad FQDN set")
|
||||||
|
|
||||||
|
// Bad regId with good FQDN set.
|
||||||
|
err = validateIdForName(CertificatesPerFQDNSetPerAccount, "lol:example.com,example.org")
|
||||||
|
test.AssertError(t, err, "bad regId with good FQDN set")
|
||||||
|
|
||||||
|
// Empty regId with good FQDN set.
|
||||||
|
err = validateIdForName(CertificatesPerFQDNSetPerAccount, ":example.com,example.org")
|
||||||
|
test.AssertError(t, err, "empty regId with good FQDN set")
|
||||||
|
|
||||||
|
// Good regId with empty FQDN set.
|
||||||
|
err = validateIdForName(CertificatesPerFQDNSetPerAccount, "1234567890:")
|
||||||
|
test.AssertError(t, err, "good regId with empty FQDN set")
|
||||||
|
|
||||||
|
// Empty regId with empty FQDN set, no separator.
|
||||||
|
err = validateIdForName(CertificatesPerFQDNSetPerAccount, "")
|
||||||
|
test.AssertError(t, err, "empty regId with empty FQDN set, no separator")
|
||||||
|
|
||||||
|
// Instead of anything we would expect, just lol.
|
||||||
|
err = validateIdForName(CertificatesPerFQDNSetPerAccount, "lol")
|
||||||
|
test.AssertError(t, err, "instead of anything we would expect, just lol")
|
||||||
|
|
||||||
|
// Valid regId with good FQDN set and a secret third separator.
|
||||||
|
err = validateIdForName(CertificatesPerFQDNSetPerAccount, "1234567890:example.com,example.org:lol")
|
||||||
|
test.AssertError(t, err, "valid regId with good FQDN set and a secret third separator")
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_loadAndParseOverrideLimits(t *testing.T) {
|
||||||
|
newRegistrationsPerIPAddressEnumStr := nameToEnumString(NewRegistrationsPerIPAddress)
|
||||||
|
newRegistrationsPerIPv6RangeEnumStr := nameToEnumString(NewRegistrationsPerIPv6Range)
|
||||||
|
|
||||||
|
// Load a single valid override limit with Id formatted as 'enum:RegId'.
|
||||||
|
l, err := loadAndParseOverrideLimits("testdata/working_override.yml")
|
||||||
|
test.AssertNotError(t, err, "valid single override limit")
|
||||||
|
expectKey := newRegistrationsPerIPAddressEnumStr + ":" + "10.0.0.2"
|
||||||
|
test.AssertEquals(t, l[expectKey].Burst, int64(40))
|
||||||
|
test.AssertEquals(t, l[expectKey].Count, int64(40))
|
||||||
|
test.AssertEquals(t, l[expectKey].Period.Duration, time.Second)
|
||||||
|
|
||||||
|
// Load single valid override limit with Id formatted as 'regId:domain'.
|
||||||
|
l, err = loadAndParseOverrideLimits("testdata/working_override_regid_domain.yml")
|
||||||
|
test.AssertNotError(t, err, "valid single override limit with Id of regId:domain")
|
||||||
|
expectKey = nameToEnumString(CertificatesPerDomainPerAccount) + ":" + "12345678:example.com"
|
||||||
|
test.AssertEquals(t, l[expectKey].Burst, int64(40))
|
||||||
|
test.AssertEquals(t, l[expectKey].Count, int64(40))
|
||||||
|
test.AssertEquals(t, l[expectKey].Period.Duration, time.Second)
|
||||||
|
|
||||||
|
// Load multiple valid override limits with 'enum:RegId' Ids.
|
||||||
|
l, err = loadAndParseOverrideLimits("testdata/working_overrides.yml")
|
||||||
|
expectKey1 := newRegistrationsPerIPAddressEnumStr + ":" + "10.0.0.2"
|
||||||
|
test.AssertNotError(t, err, "multiple valid override limits")
|
||||||
|
test.AssertEquals(t, l[expectKey1].Burst, int64(40))
|
||||||
|
test.AssertEquals(t, l[expectKey1].Count, int64(40))
|
||||||
|
test.AssertEquals(t, l[expectKey1].Period.Duration, time.Second)
|
||||||
|
expectKey2 := newRegistrationsPerIPv6RangeEnumStr + ":" + "2001:0db8:0000::/48"
|
||||||
|
test.AssertEquals(t, l[expectKey2].Burst, int64(50))
|
||||||
|
test.AssertEquals(t, l[expectKey2].Count, int64(50))
|
||||||
|
test.AssertEquals(t, l[expectKey2].Period.Duration, time.Second*2)
|
||||||
|
|
||||||
|
// Load multiple valid override limits with 'regID:fqdnSet' Ids as follows:
|
||||||
|
// - CertificatesPerFQDNSetPerAccount:12345678:example.com
|
||||||
|
// - CertificatesPerFQDNSetPerAccount:12345678:example.com,example.net
|
||||||
|
// - CertificatesPerFQDNSetPerAccount:12345678:example.com,example.net,example.org
|
||||||
|
firstEntryFQDNSetHash := string(core.HashNames([]string{"example.com"}))
|
||||||
|
secondEntryFQDNSetHash := string(core.HashNames([]string{"example.com", "example.net"}))
|
||||||
|
thirdEntryFQDNSetHash := string(core.HashNames([]string{"example.com", "example.net", "example.org"}))
|
||||||
|
firstEntryKey := nameToEnumString(CertificatesPerFQDNSetPerAccount) + ":" + "12345678:" + firstEntryFQDNSetHash
|
||||||
|
secondEntryKey := nameToEnumString(CertificatesPerFQDNSetPerAccount) + ":" + "12345678:" + secondEntryFQDNSetHash
|
||||||
|
thirdEntryKey := nameToEnumString(CertificatesPerFQDNSetPerAccount) + ":" + "12345678:" + thirdEntryFQDNSetHash
|
||||||
|
l, err = loadAndParseOverrideLimits("testdata/working_overrides_regid_fqdnset.yml")
|
||||||
|
test.AssertNotError(t, err, "multiple valid override limits with Id of regId:fqdnSets")
|
||||||
|
test.AssertEquals(t, l[firstEntryKey].Burst, int64(40))
|
||||||
|
test.AssertEquals(t, l[firstEntryKey].Count, int64(40))
|
||||||
|
test.AssertEquals(t, l[firstEntryKey].Period.Duration, time.Second)
|
||||||
|
test.AssertEquals(t, l[secondEntryKey].Burst, int64(50))
|
||||||
|
test.AssertEquals(t, l[secondEntryKey].Count, int64(50))
|
||||||
|
test.AssertEquals(t, l[secondEntryKey].Period.Duration, time.Second*2)
|
||||||
|
test.AssertEquals(t, l[thirdEntryKey].Burst, int64(60))
|
||||||
|
test.AssertEquals(t, l[thirdEntryKey].Count, int64(60))
|
||||||
|
test.AssertEquals(t, l[thirdEntryKey].Period.Duration, time.Second*3)
|
||||||
|
|
||||||
|
// Path is empty string.
|
||||||
|
_, err = loadAndParseOverrideLimits("")
|
||||||
|
test.AssertError(t, err, "path is empty string")
|
||||||
|
test.Assert(t, os.IsNotExist(err), "path is empty string")
|
||||||
|
|
||||||
|
// Path to file which does not exist.
|
||||||
|
_, err = loadAndParseOverrideLimits("testdata/file_does_not_exist.yml")
|
||||||
|
test.AssertError(t, err, "a file that does not exist ")
|
||||||
|
test.Assert(t, os.IsNotExist(err), "test file should not exist")
|
||||||
|
|
||||||
|
// Burst cannot be 0.
|
||||||
|
_, err = loadAndParseOverrideLimits("testdata/busted_override_burst_0.yml")
|
||||||
|
test.AssertError(t, err, "single override limit with burst=0")
|
||||||
|
test.Assert(t, !os.IsNotExist(err), "test file should exist")
|
||||||
|
|
||||||
|
// Id cannot be empty.
|
||||||
|
_, err = loadAndParseOverrideLimits("testdata/busted_override_empty_id.yml")
|
||||||
|
test.AssertError(t, err, "single override limit with empty id")
|
||||||
|
test.Assert(t, !os.IsNotExist(err), "test file should exist")
|
||||||
|
|
||||||
|
// Name cannot be empty.
|
||||||
|
_, err = loadAndParseOverrideLimits("testdata/busted_override_empty_name.yml")
|
||||||
|
test.AssertError(t, err, "single override limit with empty name")
|
||||||
|
test.Assert(t, !os.IsNotExist(err), "test file should exist")
|
||||||
|
|
||||||
|
// Name must be a string representation of a valid Name enumeration.
|
||||||
|
_, err = loadAndParseOverrideLimits("testdata/busted_override_invalid_name.yml")
|
||||||
|
test.AssertError(t, err, "single override limit with invalid name")
|
||||||
|
test.Assert(t, !os.IsNotExist(err), "test file should exist")
|
||||||
|
|
||||||
|
// Multiple entries, second entry has a bad name.
|
||||||
|
_, err = loadAndParseOverrideLimits("testdata/busted_overrides_second_entry_bad_name.yml")
|
||||||
|
test.AssertError(t, err, "multiple override limits, second entry is bad")
|
||||||
|
test.Assert(t, !os.IsNotExist(err), "test file should exist")
|
||||||
|
|
||||||
|
// Multiple entries, third entry has id of "lol", instead of an IPv4 address.
|
||||||
|
_, err = loadAndParseOverrideLimits("testdata/busted_overrides_third_entry_bad_id.yml")
|
||||||
|
test.AssertError(t, err, "multiple override limits, third entry has bad Id value")
|
||||||
|
test.Assert(t, !os.IsNotExist(err), "test file should exist")
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_loadAndParseDefaultLimits(t *testing.T) {
|
||||||
|
newRestistrationsPerIPv4AddressEnumStr := nameToEnumString(NewRegistrationsPerIPAddress)
|
||||||
|
newRegistrationsPerIPv6RangeEnumStr := nameToEnumString(NewRegistrationsPerIPv6Range)
|
||||||
|
|
||||||
|
// Load a single valid default limit.
|
||||||
|
l, err := loadAndParseDefaultLimits("testdata/working_default.yml")
|
||||||
|
test.AssertNotError(t, err, "valid single default limit")
|
||||||
|
test.AssertEquals(t, l[newRestistrationsPerIPv4AddressEnumStr].Burst, int64(20))
|
||||||
|
test.AssertEquals(t, l[newRestistrationsPerIPv4AddressEnumStr].Count, int64(20))
|
||||||
|
test.AssertEquals(t, l[newRestistrationsPerIPv4AddressEnumStr].Period.Duration, time.Second)
|
||||||
|
|
||||||
|
// Load multiple valid default limits.
|
||||||
|
l, err = loadAndParseDefaultLimits("testdata/working_defaults.yml")
|
||||||
|
test.AssertNotError(t, err, "multiple valid default limits")
|
||||||
|
test.AssertEquals(t, l[newRestistrationsPerIPv4AddressEnumStr].Burst, int64(20))
|
||||||
|
test.AssertEquals(t, l[newRestistrationsPerIPv4AddressEnumStr].Count, int64(20))
|
||||||
|
test.AssertEquals(t, l[newRestistrationsPerIPv4AddressEnumStr].Period.Duration, time.Second)
|
||||||
|
test.AssertEquals(t, l[newRegistrationsPerIPv6RangeEnumStr].Burst, int64(30))
|
||||||
|
test.AssertEquals(t, l[newRegistrationsPerIPv6RangeEnumStr].Count, int64(30))
|
||||||
|
test.AssertEquals(t, l[newRegistrationsPerIPv6RangeEnumStr].Period.Duration, time.Second*2)
|
||||||
|
|
||||||
|
// Path is empty string.
|
||||||
|
_, err = loadAndParseDefaultLimits("")
|
||||||
|
test.AssertError(t, err, "path is empty string")
|
||||||
|
test.Assert(t, os.IsNotExist(err), "path is empty string")
|
||||||
|
|
||||||
|
// Path to file which does not exist.
|
||||||
|
_, err = loadAndParseDefaultLimits("testdata/file_does_not_exist.yml")
|
||||||
|
test.AssertError(t, err, "a file that does not exist")
|
||||||
|
test.Assert(t, os.IsNotExist(err), "test file should not exist")
|
||||||
|
|
||||||
|
// Burst cannot be 0.
|
||||||
|
_, err = loadAndParseDefaultLimits("testdata/busted_default_burst_0.yml")
|
||||||
|
test.AssertError(t, err, "single default limit with burst=0")
|
||||||
|
test.Assert(t, !os.IsNotExist(err), "test file should exist")
|
||||||
|
|
||||||
|
// Name cannot be empty.
|
||||||
|
_, err = loadAndParseDefaultLimits("testdata/busted_default_empty_name.yml")
|
||||||
|
test.AssertError(t, err, "single default limit with empty name")
|
||||||
|
test.Assert(t, !os.IsNotExist(err), "test file should exist")
|
||||||
|
|
||||||
|
// Name must be a string representation of a valid Name enumeration.
|
||||||
|
_, err = loadAndParseDefaultLimits("testdata/busted_default_invalid_name.yml")
|
||||||
|
test.AssertError(t, err, "single default limit with invalid name")
|
||||||
|
test.Assert(t, !os.IsNotExist(err), "test file should exist")
|
||||||
|
|
||||||
|
// Multiple entries, second entry has a bad name.
|
||||||
|
_, err = loadAndParseDefaultLimits("testdata/busted_defaults_second_entry_bad_name.yml")
|
||||||
|
test.AssertError(t, err, "multiple default limits, one is bad")
|
||||||
|
test.Assert(t, !os.IsNotExist(err), "test file should exist")
|
||||||
|
}
|
|
@ -0,0 +1,234 @@
|
||||||
|
package ratelimits
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/jmhodges/clock"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ErrInvalidCost indicates that the cost specified was <= 0.
|
||||||
|
var ErrInvalidCost = fmt.Errorf("invalid cost, must be > 0")
|
||||||
|
|
||||||
|
// ErrInvalidCostForCheck indicates that the check cost specified was < 0.
|
||||||
|
var ErrInvalidCostForCheck = fmt.Errorf("invalid check cost, must be >= 0")
|
||||||
|
|
||||||
|
// ErrInvalidCostOverLimit indicates that the cost specified was > limit.Burst.
|
||||||
|
var ErrInvalidCostOverLimit = fmt.Errorf("invalid cost, must be <= limit.Burst")
|
||||||
|
|
||||||
|
// ErrBucketAlreadyFull indicates that the bucket already has reached its
|
||||||
|
// maximum capacity.
|
||||||
|
var ErrBucketAlreadyFull = fmt.Errorf("bucket already full")
|
||||||
|
|
||||||
|
// Limiter provides a high-level interface for rate limiting requests by
|
||||||
|
// utilizing a leaky bucket-style approach.
|
||||||
|
type Limiter struct {
|
||||||
|
// defaults stores default limits by 'name'.
|
||||||
|
defaults limits
|
||||||
|
|
||||||
|
// overrides stores override limits by 'name:id'.
|
||||||
|
overrides limits
|
||||||
|
|
||||||
|
// source is used to store buckets. It must be safe for concurrent use.
|
||||||
|
source source
|
||||||
|
clk clock.Clock
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewLimiter returns a new *Limiter. The provided source must be safe for
|
||||||
|
// concurrent use. The defaults and overrides paths are expected to be paths to
|
||||||
|
// YAML files that contain the default and override limits, respectively. The
|
||||||
|
// overrides file is optional, all other arguments are required.
|
||||||
|
func NewLimiter(clk clock.Clock, source source, defaults, overrides string) (*Limiter, error) {
|
||||||
|
limiter := &Limiter{source: source, clk: clk}
|
||||||
|
|
||||||
|
var err error
|
||||||
|
limiter.defaults, err = loadAndParseDefaultLimits(defaults)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if overrides == "" {
|
||||||
|
// No overrides specified, initialize an empty map.
|
||||||
|
limiter.overrides = make(limits)
|
||||||
|
return limiter, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
limiter.overrides, err = loadAndParseOverrideLimits(overrides)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return limiter, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type Decision struct {
|
||||||
|
// Allowed is true if the bucket possessed enough capacity to allow the
|
||||||
|
// request given the cost.
|
||||||
|
Allowed bool
|
||||||
|
|
||||||
|
// Remaining is the number of requests the client is allowed to make before
|
||||||
|
// they're rate limited.
|
||||||
|
Remaining int64
|
||||||
|
|
||||||
|
// RetryIn is the duration the client MUST wait before they're allowed to
|
||||||
|
// make a request.
|
||||||
|
RetryIn time.Duration
|
||||||
|
|
||||||
|
// ResetIn is the duration the bucket will take to refill to its maximum
|
||||||
|
// capacity, assuming no further requests are made.
|
||||||
|
ResetIn time.Duration
|
||||||
|
|
||||||
|
// newTAT indicates the time at which the bucket will be full. It is the
|
||||||
|
// theoretical arrival time (TAT) of next request. It must be no more than
|
||||||
|
// (burst * (period / count)) in the future at any single point in time.
|
||||||
|
newTAT time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check returns a *Decision that indicates whether there's enough capacity to
|
||||||
|
// allow the request, given the cost, for the specified limit Name and client
|
||||||
|
// id. However, it DOES NOT deduct the cost of the request from the bucket's
|
||||||
|
// capacity. Hence, the returned *Decision represents the hypothetical state of
|
||||||
|
// the bucket if the cost WERE to be deducted. The returned *Decision will
|
||||||
|
// always include the number of remaining requests in the bucket, the required
|
||||||
|
// wait time before the client can make another request, and the time until the
|
||||||
|
// bucket refills to its maximum capacity (resets). If no bucket exists for the
|
||||||
|
// given limit Name and client id, a new one will be created WITHOUT the
|
||||||
|
// request's cost deducted from its initial capacity.
|
||||||
|
func (l *Limiter) Check(name Name, id string, cost int64) (*Decision, error) {
|
||||||
|
if cost < 0 {
|
||||||
|
return nil, ErrInvalidCostForCheck
|
||||||
|
}
|
||||||
|
|
||||||
|
limit, err := l.getLimit(name, id)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if cost > limit.Burst {
|
||||||
|
return nil, ErrInvalidCostOverLimit
|
||||||
|
}
|
||||||
|
|
||||||
|
tat, err := l.source.Get(bucketKey(name, id))
|
||||||
|
if err != nil {
|
||||||
|
if !errors.Is(err, ErrBucketNotFound) {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
// First request from this client. The cost is not deducted from the
|
||||||
|
// initial capacity because this is only a check.
|
||||||
|
d, err := l.initialize(limit, name, id, 0)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return maybeSpend(l.clk, limit, d.newTAT, cost), nil
|
||||||
|
}
|
||||||
|
return maybeSpend(l.clk, limit, tat, cost), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Spend returns a *Decision that indicates if enough capacity was available to
|
||||||
|
// process the request, given the cost, for the specified limit Name and client
|
||||||
|
// id. If capacity existed, the cost of the request HAS been deducted from the
|
||||||
|
// bucket's capacity, otherwise no cost was deducted. The returned *Decision
|
||||||
|
// will always include the number of remaining requests in the bucket, the
|
||||||
|
// required wait time before the client can make another request, and the time
|
||||||
|
// until the bucket refills to its maximum capacity (resets). If no bucket
|
||||||
|
// exists for the given limit Name and client id, a new one will be created WITH
|
||||||
|
// the request's cost deducted from its initial capacity.
|
||||||
|
func (l *Limiter) Spend(name Name, id string, cost int64) (*Decision, error) {
|
||||||
|
if cost <= 0 {
|
||||||
|
return nil, ErrInvalidCost
|
||||||
|
}
|
||||||
|
|
||||||
|
limit, err := l.getLimit(name, id)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if cost > limit.Burst {
|
||||||
|
return nil, ErrInvalidCostOverLimit
|
||||||
|
}
|
||||||
|
|
||||||
|
tat, err := l.source.Get(bucketKey(name, id))
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, ErrBucketNotFound) {
|
||||||
|
// First request from this client.
|
||||||
|
return l.initialize(limit, name, id, cost)
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
d := maybeSpend(l.clk, limit, tat, cost)
|
||||||
|
|
||||||
|
if !d.Allowed {
|
||||||
|
return d, nil
|
||||||
|
}
|
||||||
|
return d, l.source.Set(bucketKey(name, id), d.newTAT)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Refund attempts to refund the cost to the bucket identified by limit name and
|
||||||
|
// client id. The returned *Decision indicates whether the refund was successful
|
||||||
|
// or not. If the refund was successful, the cost of the request was added back
|
||||||
|
// to the bucket's capacity. If the refund is not possible (i.e., the bucket is
|
||||||
|
// already full or the refund amount is invalid), no cost is refunded.
|
||||||
|
//
|
||||||
|
// Note: The amount refunded cannot cause the bucket to exceed its maximum
|
||||||
|
// capacity. However, partial refunds are allowed and are considered successful.
|
||||||
|
// For instance, if a bucket has a maximum capacity of 10 and currently has 5
|
||||||
|
// requests remaining, a refund request of 7 will result in the bucket reaching
|
||||||
|
// its maximum capacity of 10, not 12.
|
||||||
|
func (l *Limiter) Refund(name Name, id string, cost int64) (*Decision, error) {
|
||||||
|
if cost <= 0 {
|
||||||
|
return nil, ErrInvalidCost
|
||||||
|
}
|
||||||
|
|
||||||
|
limit, err := l.getLimit(name, id)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
tat, err := l.source.Get(bucketKey(name, id))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
d := maybeRefund(l.clk, limit, tat, cost)
|
||||||
|
if !d.Allowed {
|
||||||
|
return d, ErrBucketAlreadyFull
|
||||||
|
}
|
||||||
|
return d, l.source.Set(bucketKey(name, id), d.newTAT)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reset resets the specified bucket.
|
||||||
|
func (l *Limiter) Reset(name Name, id string) error {
|
||||||
|
return l.source.Delete(bucketKey(name, id))
|
||||||
|
}
|
||||||
|
|
||||||
|
// initialize creates a new bucket, specified by limit name and id, with the
|
||||||
|
// cost of the request factored into the initial state.
|
||||||
|
func (l *Limiter) initialize(rl limit, name Name, id string, cost int64) (*Decision, error) {
|
||||||
|
d := maybeSpend(l.clk, rl, l.clk.Now(), cost)
|
||||||
|
err := l.source.Set(bucketKey(name, id), d.newTAT)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return d, nil
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetLimit returns the limit for the specified by name and id, name is
|
||||||
|
// required, id is optional. If id is left unspecified, the default limit for
|
||||||
|
// the limit specified by name is returned.
|
||||||
|
func (l *Limiter) getLimit(name Name, id string) (limit, error) {
|
||||||
|
if id != "" {
|
||||||
|
// Check for override.
|
||||||
|
ol, ok := l.overrides[bucketKey(name, id)]
|
||||||
|
if ok {
|
||||||
|
return ol, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
dl, ok := l.defaults[nameToEnumString(name)]
|
||||||
|
if ok {
|
||||||
|
return dl, nil
|
||||||
|
}
|
||||||
|
return limit{}, fmt.Errorf("limit %q does not exist", name)
|
||||||
|
}
|
|
@ -0,0 +1,315 @@
|
||||||
|
package ratelimits
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/jmhodges/clock"
|
||||||
|
"github.com/letsencrypt/boulder/test"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
tenZeroZeroOne = "10.0.0.1"
|
||||||
|
tenZeroZeroTwo = "10.0.0.2"
|
||||||
|
)
|
||||||
|
|
||||||
|
// newTestLimiter makes a new limiter with the following configuration:
|
||||||
|
// - 'NewRegistrationsPerIPAddress' burst: 20 count: 20 period: 1s
|
||||||
|
func newTestLimiter(t *testing.T) (*Limiter, clock.FakeClock) {
|
||||||
|
clk := clock.NewFake()
|
||||||
|
l, err := NewLimiter(clk, newInmem(), "testdata/working_default.yml", "")
|
||||||
|
test.AssertNotError(t, err, "should not error")
|
||||||
|
return l, clk
|
||||||
|
}
|
||||||
|
|
||||||
|
// newTestLimiterWithOverrides makes a new limiter with the following
|
||||||
|
// configuration:
|
||||||
|
// - 'NewRegistrationsPerIPAddress' burst: 20 count: 20 period: 1s
|
||||||
|
// - 'NewRegistrationsPerIPAddress:10.0.0.2' burst: 40 count: 40 period: 1s
|
||||||
|
func newTestLimiterWithOverrides(t *testing.T) (*Limiter, clock.FakeClock) {
|
||||||
|
clk := clock.NewFake()
|
||||||
|
l, err := NewLimiter(clk, newInmem(), "testdata/working_default.yml", "testdata/working_override.yml")
|
||||||
|
test.AssertNotError(t, err, "should not error")
|
||||||
|
return l, clk
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_Limiter_initialization_via_Check_and_Spend(t *testing.T) {
|
||||||
|
l, _ := newTestLimiter(t)
|
||||||
|
|
||||||
|
// Check on an empty bucket should initialize it and return the theoretical
|
||||||
|
// next state of that bucket if the cost were spent.
|
||||||
|
d, err := l.Check(NewRegistrationsPerIPAddress, tenZeroZeroOne, 1)
|
||||||
|
test.AssertNotError(t, err, "should not error")
|
||||||
|
test.Assert(t, d.Allowed, "should be allowed")
|
||||||
|
test.AssertEquals(t, d.Remaining, int64(19))
|
||||||
|
// Verify our ResetIn timing is correct. 1 second == 1000 milliseconds and
|
||||||
|
// 1000/20 = 50 milliseconds per request.
|
||||||
|
test.AssertEquals(t, d.ResetIn, time.Millisecond*50)
|
||||||
|
test.AssertEquals(t, d.RetryIn, time.Duration(0))
|
||||||
|
|
||||||
|
// However, that cost should not be spent yet, a 0 cost check should tell us
|
||||||
|
// that we actually have 20 remaining.
|
||||||
|
d, err = l.Check(NewRegistrationsPerIPAddress, tenZeroZeroOne, 0)
|
||||||
|
test.AssertNotError(t, err, "should not error")
|
||||||
|
test.Assert(t, d.Allowed, "should be allowed")
|
||||||
|
test.AssertEquals(t, d.Remaining, int64(20))
|
||||||
|
test.AssertEquals(t, d.ResetIn, time.Duration(0))
|
||||||
|
test.AssertEquals(t, d.RetryIn, time.Duration(0))
|
||||||
|
|
||||||
|
// Reset our bucket.
|
||||||
|
err = l.Reset(NewRegistrationsPerIPAddress, tenZeroZeroOne)
|
||||||
|
test.AssertNotError(t, err, "should not error")
|
||||||
|
|
||||||
|
// Similar to above, but we'll use Spend() instead of Check() to initialize
|
||||||
|
// the bucket. Spend should return the same result as Check.
|
||||||
|
d, err = l.Spend(NewRegistrationsPerIPAddress, tenZeroZeroOne, 1)
|
||||||
|
test.AssertNotError(t, err, "should not error")
|
||||||
|
test.Assert(t, d.Allowed, "should be allowed")
|
||||||
|
test.AssertEquals(t, d.Remaining, int64(19))
|
||||||
|
// Verify our ResetIn timing is correct. 1 second == 1000 milliseconds and
|
||||||
|
// 1000/20 = 50 milliseconds per request.
|
||||||
|
test.AssertEquals(t, d.ResetIn, time.Millisecond*50)
|
||||||
|
test.AssertEquals(t, d.RetryIn, time.Duration(0))
|
||||||
|
|
||||||
|
// However, that cost should not be spent yet, a 0 cost check should tell us
|
||||||
|
// that we actually have 19 remaining.
|
||||||
|
d, err = l.Check(NewRegistrationsPerIPAddress, tenZeroZeroOne, 0)
|
||||||
|
test.AssertNotError(t, err, "should not error")
|
||||||
|
test.Assert(t, d.Allowed, "should be allowed")
|
||||||
|
test.AssertEquals(t, d.Remaining, int64(19))
|
||||||
|
// Verify our ResetIn is correct. 1 second == 1000 milliseconds and
|
||||||
|
// 1000/20 = 50 milliseconds per request.
|
||||||
|
test.AssertEquals(t, d.ResetIn, time.Millisecond*50)
|
||||||
|
test.AssertEquals(t, d.RetryIn, time.Duration(0))
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_Limiter_Refund_and_Spend_cost_err(t *testing.T) {
|
||||||
|
l, _ := newTestLimiter(t)
|
||||||
|
|
||||||
|
// Spend a cost of 0, which should fail.
|
||||||
|
_, err := l.Spend(NewRegistrationsPerIPAddress, tenZeroZeroOne, 0)
|
||||||
|
test.AssertErrorIs(t, err, ErrInvalidCost)
|
||||||
|
|
||||||
|
// Spend a negative cost, which should fail.
|
||||||
|
_, err = l.Spend(NewRegistrationsPerIPAddress, tenZeroZeroOne, -1)
|
||||||
|
test.AssertErrorIs(t, err, ErrInvalidCost)
|
||||||
|
|
||||||
|
// Refund a cost of 0, which should fail.
|
||||||
|
_, err = l.Refund(NewRegistrationsPerIPAddress, tenZeroZeroOne, 0)
|
||||||
|
test.AssertErrorIs(t, err, ErrInvalidCost)
|
||||||
|
|
||||||
|
// Refund a negative cost, which should fail.
|
||||||
|
_, err = l.Refund(NewRegistrationsPerIPAddress, tenZeroZeroOne, -1)
|
||||||
|
test.AssertErrorIs(t, err, ErrInvalidCost)
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_Limiter_with_bad_limits_path(t *testing.T) {
|
||||||
|
_, err := NewLimiter(clock.NewFake(), newInmem(), "testdata/does-not-exist.yml", "")
|
||||||
|
test.AssertError(t, err, "should error")
|
||||||
|
|
||||||
|
_, err = NewLimiter(clock.NewFake(), newInmem(), "testdata/defaults.yml", "testdata/does-not-exist.yml")
|
||||||
|
test.AssertError(t, err, "should error")
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_Limiter_Check_bad_cost(t *testing.T) {
|
||||||
|
l, _ := newTestLimiter(t)
|
||||||
|
_, err := l.Check(NewRegistrationsPerIPAddress, tenZeroZeroOne, -1)
|
||||||
|
test.AssertErrorIs(t, err, ErrInvalidCostForCheck)
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_Limiter_Check_limit_no_exist(t *testing.T) {
|
||||||
|
l, _ := newTestLimiter(t)
|
||||||
|
_, err := l.Check(Name(9999), tenZeroZeroOne, 1)
|
||||||
|
test.AssertError(t, err, "should error")
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_Limiter_getLimit_no_exist(t *testing.T) {
|
||||||
|
l, _ := newTestLimiter(t)
|
||||||
|
_, err := l.getLimit(Name(9999), "")
|
||||||
|
test.AssertError(t, err, "should error")
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_Limiter_with_defaults(t *testing.T) {
|
||||||
|
l, clk := newTestLimiter(t)
|
||||||
|
|
||||||
|
// Attempt to spend 21 requests (a cost > the limit burst capacity), this
|
||||||
|
// should fail with a specific error.
|
||||||
|
_, err := l.Spend(NewRegistrationsPerIPAddress, tenZeroZeroOne, 21)
|
||||||
|
test.AssertErrorIs(t, err, ErrInvalidCostOverLimit)
|
||||||
|
|
||||||
|
// Attempt to spend all 20 requests, this should succeed.
|
||||||
|
d, err := l.Spend(NewRegistrationsPerIPAddress, tenZeroZeroOne, 20)
|
||||||
|
test.AssertNotError(t, err, "should not error")
|
||||||
|
test.Assert(t, d.Allowed, "should be allowed")
|
||||||
|
test.AssertEquals(t, d.Remaining, int64(0))
|
||||||
|
test.AssertEquals(t, d.ResetIn, time.Second)
|
||||||
|
|
||||||
|
// Attempting to spend 1 more, this should fail.
|
||||||
|
d, err = l.Spend(NewRegistrationsPerIPAddress, tenZeroZeroOne, 1)
|
||||||
|
test.AssertNotError(t, err, "should not error")
|
||||||
|
test.Assert(t, !d.Allowed, "should not be allowed")
|
||||||
|
test.AssertEquals(t, d.Remaining, int64(0))
|
||||||
|
test.AssertEquals(t, d.ResetIn, time.Second)
|
||||||
|
|
||||||
|
// Verify our ResetIn is correct. 1 second == 1000 milliseconds and
|
||||||
|
// 1000/20 = 50 milliseconds per request.
|
||||||
|
test.AssertEquals(t, d.RetryIn, time.Millisecond*50)
|
||||||
|
|
||||||
|
// Wait 50 milliseconds and try again.
|
||||||
|
clk.Add(d.RetryIn)
|
||||||
|
|
||||||
|
// We should be allowed to spend 1 more request.
|
||||||
|
d, err = l.Spend(NewRegistrationsPerIPAddress, tenZeroZeroOne, 1)
|
||||||
|
test.AssertNotError(t, err, "should not error")
|
||||||
|
test.Assert(t, d.Allowed, "should be allowed")
|
||||||
|
test.AssertEquals(t, d.Remaining, int64(0))
|
||||||
|
test.AssertEquals(t, d.ResetIn, time.Second)
|
||||||
|
|
||||||
|
// Wait 1 second for a full bucket reset.
|
||||||
|
clk.Add(d.ResetIn)
|
||||||
|
|
||||||
|
// Quickly spend 20 requests in a row.
|
||||||
|
for i := 0; i < 20; i++ {
|
||||||
|
d, err = l.Spend(NewRegistrationsPerIPAddress, tenZeroZeroOne, 1)
|
||||||
|
test.AssertNotError(t, err, "should not error")
|
||||||
|
test.Assert(t, d.Allowed, "should be allowed")
|
||||||
|
test.AssertEquals(t, d.Remaining, int64(19-i))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Attempting to spend 1 more, this should fail.
|
||||||
|
d, err = l.Spend(NewRegistrationsPerIPAddress, tenZeroZeroOne, 1)
|
||||||
|
test.AssertNotError(t, err, "should not error")
|
||||||
|
test.Assert(t, !d.Allowed, "should not be allowed")
|
||||||
|
test.AssertEquals(t, d.Remaining, int64(0))
|
||||||
|
test.AssertEquals(t, d.ResetIn, time.Second)
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_Limiter_with_limit_overrides(t *testing.T) {
|
||||||
|
l, clk := newTestLimiterWithOverrides(t)
|
||||||
|
|
||||||
|
// Attempt to check a spend of 41 requests (a cost > the limit burst
|
||||||
|
// capacity), this should fail with a specific error.
|
||||||
|
_, err := l.Check(NewRegistrationsPerIPAddress, tenZeroZeroTwo, 41)
|
||||||
|
test.AssertErrorIs(t, err, ErrInvalidCostOverLimit)
|
||||||
|
|
||||||
|
// Attempt to spend 41 requests (a cost > the limit burst capacity), this
|
||||||
|
// should fail with a specific error.
|
||||||
|
_, err = l.Spend(NewRegistrationsPerIPAddress, tenZeroZeroTwo, 41)
|
||||||
|
test.AssertErrorIs(t, err, ErrInvalidCostOverLimit)
|
||||||
|
|
||||||
|
// Attempt to spend all 40 requests, this should succeed.
|
||||||
|
d, err := l.Spend(NewRegistrationsPerIPAddress, tenZeroZeroTwo, 40)
|
||||||
|
test.AssertNotError(t, err, "should not error")
|
||||||
|
test.Assert(t, d.Allowed, "should be allowed")
|
||||||
|
|
||||||
|
// Attempting to spend 1 more, this should fail.
|
||||||
|
d, err = l.Spend(NewRegistrationsPerIPAddress, tenZeroZeroTwo, 1)
|
||||||
|
test.AssertNotError(t, err, "should not error")
|
||||||
|
test.Assert(t, !d.Allowed, "should not be allowed")
|
||||||
|
test.AssertEquals(t, d.Remaining, int64(0))
|
||||||
|
test.AssertEquals(t, d.ResetIn, time.Second)
|
||||||
|
|
||||||
|
// Verify our ResetIn is correct. 1 second == 1000 milliseconds and
|
||||||
|
// 1000/40 = 25 milliseconds per request.
|
||||||
|
test.AssertEquals(t, d.RetryIn, time.Millisecond*25)
|
||||||
|
|
||||||
|
// Wait 50 milliseconds and try again.
|
||||||
|
clk.Add(d.RetryIn)
|
||||||
|
|
||||||
|
// We should be allowed to spend 1 more request.
|
||||||
|
d, err = l.Spend(NewRegistrationsPerIPAddress, tenZeroZeroTwo, 1)
|
||||||
|
test.AssertNotError(t, err, "should not error")
|
||||||
|
test.Assert(t, d.Allowed, "should be allowed")
|
||||||
|
test.AssertEquals(t, d.Remaining, int64(0))
|
||||||
|
test.AssertEquals(t, d.ResetIn, time.Second)
|
||||||
|
|
||||||
|
// Wait 1 second for a full bucket reset.
|
||||||
|
clk.Add(d.ResetIn)
|
||||||
|
|
||||||
|
// Quickly spend 40 requests in a row.
|
||||||
|
for i := 0; i < 40; i++ {
|
||||||
|
d, err = l.Spend(NewRegistrationsPerIPAddress, tenZeroZeroTwo, 1)
|
||||||
|
test.AssertNotError(t, err, "should not error")
|
||||||
|
test.Assert(t, d.Allowed, "should be allowed")
|
||||||
|
test.AssertEquals(t, d.Remaining, int64(39-i))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Attempting to spend 1 more, this should fail.
|
||||||
|
d, err = l.Spend(NewRegistrationsPerIPAddress, tenZeroZeroTwo, 1)
|
||||||
|
test.AssertNotError(t, err, "should not error")
|
||||||
|
test.Assert(t, !d.Allowed, "should not be allowed")
|
||||||
|
test.AssertEquals(t, d.Remaining, int64(0))
|
||||||
|
test.AssertEquals(t, d.ResetIn, time.Second)
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_Limiter_with_new_clients(t *testing.T) {
|
||||||
|
l, _ := newTestLimiter(t)
|
||||||
|
|
||||||
|
// Attempt to spend all 20 requests, this should succeed.
|
||||||
|
d, err := l.Spend(NewRegistrationsPerIPAddress, tenZeroZeroOne, 20)
|
||||||
|
test.AssertNotError(t, err, "should not error")
|
||||||
|
test.Assert(t, d.Allowed, "should be allowed")
|
||||||
|
test.AssertEquals(t, d.Remaining, int64(0))
|
||||||
|
test.AssertEquals(t, d.ResetIn, time.Second)
|
||||||
|
|
||||||
|
// Another new client, spend 1 and check our remaining.
|
||||||
|
d, err = l.Spend(NewRegistrationsPerIPAddress, "10.0.0.100", 1)
|
||||||
|
test.AssertNotError(t, err, "should not error")
|
||||||
|
test.Assert(t, d.Allowed, "should be allowed")
|
||||||
|
test.AssertEquals(t, d.Remaining, int64(19))
|
||||||
|
test.AssertEquals(t, d.RetryIn, time.Duration(0))
|
||||||
|
|
||||||
|
// 1 second == 1000 milliseconds and 1000/20 = 50 milliseconds per request.
|
||||||
|
test.AssertEquals(t, d.ResetIn, time.Millisecond*50)
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_Limiter_Refund_and_Reset(t *testing.T) {
|
||||||
|
l, clk := newTestLimiter(t)
|
||||||
|
|
||||||
|
// Attempt to spend all 20 requests, this should succeed.
|
||||||
|
d, err := l.Spend(NewRegistrationsPerIPAddress, tenZeroZeroOne, 20)
|
||||||
|
test.AssertNotError(t, err, "should not error")
|
||||||
|
test.Assert(t, d.Allowed, "should be allowed")
|
||||||
|
test.AssertEquals(t, d.Remaining, int64(0))
|
||||||
|
test.AssertEquals(t, d.ResetIn, time.Second)
|
||||||
|
|
||||||
|
// Refund 10 requests.
|
||||||
|
d, err = l.Refund(NewRegistrationsPerIPAddress, tenZeroZeroOne, 10)
|
||||||
|
test.AssertNotError(t, err, "should not error")
|
||||||
|
test.AssertEquals(t, d.Remaining, int64(10))
|
||||||
|
|
||||||
|
// Spend 10 requests, this should succeed.
|
||||||
|
d, err = l.Spend(NewRegistrationsPerIPAddress, tenZeroZeroOne, 10)
|
||||||
|
test.AssertNotError(t, err, "should not error")
|
||||||
|
test.Assert(t, d.Allowed, "should be allowed")
|
||||||
|
test.AssertEquals(t, d.Remaining, int64(0))
|
||||||
|
test.AssertEquals(t, d.ResetIn, time.Second)
|
||||||
|
|
||||||
|
err = l.Reset(NewRegistrationsPerIPAddress, tenZeroZeroOne)
|
||||||
|
test.AssertNotError(t, err, "should not error")
|
||||||
|
|
||||||
|
// Attempt to spend 20 more requests, this should succeed.
|
||||||
|
d, err = l.Spend(NewRegistrationsPerIPAddress, tenZeroZeroOne, 20)
|
||||||
|
test.AssertNotError(t, err, "should not error")
|
||||||
|
test.Assert(t, d.Allowed, "should be allowed")
|
||||||
|
test.AssertEquals(t, d.Remaining, int64(0))
|
||||||
|
test.AssertEquals(t, d.ResetIn, time.Second)
|
||||||
|
|
||||||
|
// Reset to full.
|
||||||
|
clk.Add(d.ResetIn)
|
||||||
|
|
||||||
|
// Refund 1 requests above our limit, this should fail.
|
||||||
|
d, err = l.Refund(NewRegistrationsPerIPAddress, tenZeroZeroOne, 1)
|
||||||
|
test.AssertErrorIs(t, err, ErrBucketAlreadyFull)
|
||||||
|
test.AssertEquals(t, d.Remaining, int64(20))
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_Limiter_Check_Spend_parity(t *testing.T) {
|
||||||
|
il, _ := newTestLimiter(t)
|
||||||
|
jl, _ := newTestLimiter(t)
|
||||||
|
i, err := il.Check(NewRegistrationsPerIPAddress, tenZeroZeroOne, 1)
|
||||||
|
test.AssertNotError(t, err, "should not error")
|
||||||
|
j, err := jl.Spend(NewRegistrationsPerIPAddress, tenZeroZeroOne, 1)
|
||||||
|
test.AssertNotError(t, err, "should not error")
|
||||||
|
test.AssertDeepEquals(t, i.Remaining, j.Remaining)
|
||||||
|
}
|
|
@ -0,0 +1,202 @@
|
||||||
|
package ratelimits
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/letsencrypt/boulder/policy"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Name is an enumeration of all rate limit names. It is used to intern rate
|
||||||
|
// limit names as strings and to provide a type-safe way to refer to rate
|
||||||
|
// limits.
|
||||||
|
//
|
||||||
|
// IMPORTANT: If you add a new limit Name, you MUST add it to the 'nameToString'
|
||||||
|
// mapping and idValidForName function below.
|
||||||
|
type Name int
|
||||||
|
|
||||||
|
const (
|
||||||
|
// Unknown is the zero value of Name and is used to indicate an unknown
|
||||||
|
// limit name.
|
||||||
|
Unknown Name = iota
|
||||||
|
|
||||||
|
// NewRegistrationsPerIPAddress uses bucket key 'enum:ipAddress'.
|
||||||
|
NewRegistrationsPerIPAddress
|
||||||
|
|
||||||
|
// NewRegistrationsPerIPv6Range uses bucket key 'enum:ipv6rangeCIDR'. The
|
||||||
|
// address range must be a /48.
|
||||||
|
NewRegistrationsPerIPv6Range
|
||||||
|
|
||||||
|
// NewOrdersPerAccount uses bucket key 'enum:regId'.
|
||||||
|
NewOrdersPerAccount
|
||||||
|
|
||||||
|
// FailedAuthorizationsPerAccount uses bucket key 'enum:regId', where regId
|
||||||
|
// is the registration id of the account.
|
||||||
|
FailedAuthorizationsPerAccount
|
||||||
|
|
||||||
|
// CertificatesPerDomainPerAccount uses bucket key 'enum:regId:domain',
|
||||||
|
// where name is the a name in a certificate issued to the account matching
|
||||||
|
// regId.
|
||||||
|
CertificatesPerDomainPerAccount
|
||||||
|
|
||||||
|
// CertificatesPerFQDNSetPerAccount uses bucket key 'enum:regId:fqdnSet',
|
||||||
|
// where nameSet is a set of names in a certificate issued to the account
|
||||||
|
// matching regId.
|
||||||
|
CertificatesPerFQDNSetPerAccount
|
||||||
|
)
|
||||||
|
|
||||||
|
// nameToString is a map of Name values to string names.
|
||||||
|
var nameToString = map[Name]string{
|
||||||
|
Unknown: "Unknown",
|
||||||
|
NewRegistrationsPerIPAddress: "NewRegistrationsPerIPAddress",
|
||||||
|
NewRegistrationsPerIPv6Range: "NewRegistrationsPerIPv6Range",
|
||||||
|
NewOrdersPerAccount: "NewOrdersPerAccount",
|
||||||
|
FailedAuthorizationsPerAccount: "FailedAuthorizationsPerAccount",
|
||||||
|
CertificatesPerDomainPerAccount: "CertificatesPerDomainPerAccount",
|
||||||
|
CertificatesPerFQDNSetPerAccount: "CertificatesPerFQDNSetPerAccount",
|
||||||
|
}
|
||||||
|
|
||||||
|
// validIPAddress validates that the provided string is a valid IP address.
|
||||||
|
func validIPAddress(id string) error {
|
||||||
|
ip := net.ParseIP(id)
|
||||||
|
if ip == nil {
|
||||||
|
return fmt.Errorf("invalid IP address, %q must be an IP address", id)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// validIPv6RangeCIDR validates that the provided string is formatted is an IPv6
|
||||||
|
// CIDR range with a /48 mask.
|
||||||
|
func validIPv6RangeCIDR(id string) error {
|
||||||
|
_, ipNet, err := net.ParseCIDR(id)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf(
|
||||||
|
"invalid CIDR, %q must be an IPv6 CIDR range", id)
|
||||||
|
}
|
||||||
|
ones, _ := ipNet.Mask.Size()
|
||||||
|
if ones != 48 {
|
||||||
|
// This also catches the case where the range is an IPv4 CIDR, since an
|
||||||
|
// IPv4 CIDR can't have a /48 subnet mask - the maximum is /32.
|
||||||
|
return fmt.Errorf(
|
||||||
|
"invalid CIDR, %q must be /48", id)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// validateRegId validates that the provided string is a valid ACME regId.
|
||||||
|
func validateRegId(id string) error {
|
||||||
|
_, err := strconv.ParseUint(id, 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("invalid regId, %q must be an ACME registration Id", id)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// validateRegIdDomain validates that the provided string is formatted
|
||||||
|
// 'regId:domain', where regId is an ACME registration Id and domain is a single
|
||||||
|
// domain name.
|
||||||
|
func validateRegIdDomain(id string) error {
|
||||||
|
parts := strings.SplitN(id, ":", 2)
|
||||||
|
if len(parts) != 2 {
|
||||||
|
return fmt.Errorf(
|
||||||
|
"invalid regId:domain, %q must be formatted 'regId:domain'", id)
|
||||||
|
}
|
||||||
|
if validateRegId(parts[0]) != nil {
|
||||||
|
return fmt.Errorf(
|
||||||
|
"invalid regId, %q must be formatted 'regId:domain'", id)
|
||||||
|
}
|
||||||
|
if policy.ValidDomain(parts[1]) != nil {
|
||||||
|
return fmt.Errorf(
|
||||||
|
"invalid domain, %q must be formatted 'regId:domain'", id)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// validateRegIdFQDNSet validates that the provided string is formatted
|
||||||
|
// 'regId:fqdnSet', where regId is an ACME registration Id and fqdnSet is a
|
||||||
|
// comma-separated list of domain names.
|
||||||
|
func validateRegIdFQDNSet(id string) error {
|
||||||
|
parts := strings.SplitN(id, ":", 2)
|
||||||
|
if len(parts) != 2 {
|
||||||
|
return fmt.Errorf(
|
||||||
|
"invalid regId:fqdnSet, %q must be formatted 'regId:fqdnSet'", id)
|
||||||
|
}
|
||||||
|
if validateRegId(parts[0]) != nil {
|
||||||
|
return fmt.Errorf(
|
||||||
|
"invalid regId, %q must be formatted 'regId:fqdnSet'", id)
|
||||||
|
}
|
||||||
|
domains := strings.Split(parts[1], ",")
|
||||||
|
if len(domains) == 0 {
|
||||||
|
return fmt.Errorf(
|
||||||
|
"invalid fqdnSet, %q must be formatted 'regId:fqdnSet'", id)
|
||||||
|
}
|
||||||
|
for _, domain := range domains {
|
||||||
|
if policy.ValidDomain(domain) != nil {
|
||||||
|
return fmt.Errorf(
|
||||||
|
"invalid domain, %q must be formatted 'regId:fqdnSet'", id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func validateIdForName(name Name, id string) error {
|
||||||
|
switch name {
|
||||||
|
case NewRegistrationsPerIPAddress:
|
||||||
|
// 'enum:ipaddress'
|
||||||
|
return validIPAddress(id)
|
||||||
|
|
||||||
|
case NewRegistrationsPerIPv6Range:
|
||||||
|
// 'enum:ipv6rangeCIDR'
|
||||||
|
return validIPv6RangeCIDR(id)
|
||||||
|
|
||||||
|
case NewOrdersPerAccount, FailedAuthorizationsPerAccount:
|
||||||
|
// 'enum:regId'
|
||||||
|
return validateRegId(id)
|
||||||
|
|
||||||
|
case CertificatesPerDomainPerAccount:
|
||||||
|
// 'enum:regId:domain'
|
||||||
|
return validateRegIdDomain(id)
|
||||||
|
|
||||||
|
case CertificatesPerFQDNSetPerAccount:
|
||||||
|
// 'enum:regId:fqdnSet'
|
||||||
|
return validateRegIdFQDNSet(id)
|
||||||
|
|
||||||
|
case Unknown:
|
||||||
|
fallthrough
|
||||||
|
|
||||||
|
default:
|
||||||
|
// This should never happen.
|
||||||
|
return fmt.Errorf("unknown limit enum %q", name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// stringToName is a map of string names to Name values.
|
||||||
|
var stringToName = func() map[string]Name {
|
||||||
|
m := make(map[string]Name, len(nameToString))
|
||||||
|
for k, v := range nameToString {
|
||||||
|
m[v] = k
|
||||||
|
}
|
||||||
|
return m
|
||||||
|
}()
|
||||||
|
|
||||||
|
// limitNames is a slice of all rate limit names.
|
||||||
|
var limitNames = func() []string {
|
||||||
|
names := make([]string, len(nameToString))
|
||||||
|
for _, v := range nameToString {
|
||||||
|
names = append(names, v)
|
||||||
|
}
|
||||||
|
return names
|
||||||
|
}()
|
||||||
|
|
||||||
|
// nameToEnumString converts the integer value of the Name enumeration to its
|
||||||
|
// string representation.
|
||||||
|
func nameToEnumString(s Name) string {
|
||||||
|
return strconv.Itoa(int(s))
|
||||||
|
}
|
||||||
|
|
||||||
|
// bucketKey returns the key used to store a rate limit bucket.
|
||||||
|
func bucketKey(name Name, id string) string {
|
||||||
|
return nameToEnumString(name) + ":" + id
|
||||||
|
}
|
|
@ -0,0 +1,57 @@
|
||||||
|
package ratelimits
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ErrBucketNotFound indicates that the bucket was not found.
|
||||||
|
var ErrBucketNotFound = fmt.Errorf("bucket not found")
|
||||||
|
|
||||||
|
// source is an interface for creating and modifying TATs.
|
||||||
|
type source interface {
|
||||||
|
// Set stores the TAT at the specified bucketKey ('name:id').
|
||||||
|
Set(bucketKey string, tat time.Time) error
|
||||||
|
|
||||||
|
// Get retrieves the TAT at the specified bucketKey ('name:id').
|
||||||
|
Get(bucketKey string) (time.Time, error)
|
||||||
|
|
||||||
|
// Delete deletes the TAT at the specified bucketKey ('name:id').
|
||||||
|
Delete(bucketKey string) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// inmem is an in-memory implementation of the source interface used for
|
||||||
|
// testing.
|
||||||
|
type inmem struct {
|
||||||
|
sync.RWMutex
|
||||||
|
m map[string]time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
func newInmem() *inmem {
|
||||||
|
return &inmem{m: make(map[string]time.Time)}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (in *inmem) Set(bucketKey string, tat time.Time) error {
|
||||||
|
in.Lock()
|
||||||
|
defer in.Unlock()
|
||||||
|
in.m[bucketKey] = tat
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (in *inmem) Get(bucketKey string) (time.Time, error) {
|
||||||
|
in.RLock()
|
||||||
|
defer in.RUnlock()
|
||||||
|
tat, ok := in.m[bucketKey]
|
||||||
|
if !ok {
|
||||||
|
return time.Time{}, ErrBucketNotFound
|
||||||
|
}
|
||||||
|
return tat, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (in *inmem) Delete(bucketKey string) error {
|
||||||
|
in.Lock()
|
||||||
|
defer in.Unlock()
|
||||||
|
delete(in.m, bucketKey)
|
||||||
|
return nil
|
||||||
|
}
|
|
@ -0,0 +1,4 @@
|
||||||
|
NewRegistrationsPerIPAddress:
|
||||||
|
burst: 0
|
||||||
|
count: 20
|
||||||
|
period: 1s
|
|
@ -0,0 +1,4 @@
|
||||||
|
"":
|
||||||
|
burst: 20
|
||||||
|
count: 20
|
||||||
|
period: 1s
|
|
@ -0,0 +1,4 @@
|
||||||
|
UsageRequestsPerIPv10Address:
|
||||||
|
burst: 20
|
||||||
|
count: 20
|
||||||
|
period: 1s
|
|
@ -0,0 +1,8 @@
|
||||||
|
NewRegistrationsPerIPAddress:
|
||||||
|
burst: 20
|
||||||
|
count: 20
|
||||||
|
period: 1s
|
||||||
|
UsageRequestsPerIPv10Address:
|
||||||
|
burst: 20
|
||||||
|
count: 20
|
||||||
|
period: 1s
|
|
@ -0,0 +1,4 @@
|
||||||
|
NewRegistrationsPerIPAddress:10.0.0.2:
|
||||||
|
burst: 0
|
||||||
|
count: 40
|
||||||
|
period: 1s
|
|
@ -0,0 +1,4 @@
|
||||||
|
"UsageRequestsPerIPv10Address:":
|
||||||
|
burst: 40
|
||||||
|
count: 40
|
||||||
|
period: 1s
|
|
@ -0,0 +1,4 @@
|
||||||
|
":10.0.0.2":
|
||||||
|
burst: 40
|
||||||
|
count: 40
|
||||||
|
period: 1s
|
|
@ -0,0 +1,4 @@
|
||||||
|
UsageRequestsPerIPv10Address:10.0.0.2:
|
||||||
|
burst: 40
|
||||||
|
count: 40
|
||||||
|
period: 1s
|
|
@ -0,0 +1,8 @@
|
||||||
|
NewRegistrationsPerIPAddress:10.0.0.2:
|
||||||
|
burst: 40
|
||||||
|
count: 40
|
||||||
|
period: 1s
|
||||||
|
UsageRequestsPerIPv10Address:10.0.0.5:
|
||||||
|
burst: 40
|
||||||
|
count: 40
|
||||||
|
period: 1s
|
|
@ -0,0 +1,12 @@
|
||||||
|
NewRegistrationsPerIPAddress:10.0.0.2:
|
||||||
|
burst: 40
|
||||||
|
count: 40
|
||||||
|
period: 1s
|
||||||
|
NewRegistrationsPerIPAddress:10.0.0.5:
|
||||||
|
burst: 40
|
||||||
|
count: 40
|
||||||
|
period: 1s
|
||||||
|
NewRegistrationsPerIPAddress:lol:
|
||||||
|
burst: 40
|
||||||
|
count: 40
|
||||||
|
period: 1s
|
|
@ -0,0 +1,4 @@
|
||||||
|
NewRegistrationsPerIPAddress:
|
||||||
|
burst: 20
|
||||||
|
count: 20
|
||||||
|
period: 1s
|
|
@ -0,0 +1,8 @@
|
||||||
|
NewRegistrationsPerIPAddress:
|
||||||
|
burst: 20
|
||||||
|
count: 20
|
||||||
|
period: 1s
|
||||||
|
NewRegistrationsPerIPv6Range:
|
||||||
|
burst: 30
|
||||||
|
count: 30
|
||||||
|
period: 2s
|
|
@ -0,0 +1,4 @@
|
||||||
|
NewRegistrationsPerIPAddress:10.0.0.2:
|
||||||
|
burst: 40
|
||||||
|
count: 40
|
||||||
|
period: 1s
|
|
@ -0,0 +1,4 @@
|
||||||
|
CertificatesPerDomainPerAccount:12345678:example.com:
|
||||||
|
burst: 40
|
||||||
|
count: 40
|
||||||
|
period: 1s
|
|
@ -0,0 +1,8 @@
|
||||||
|
NewRegistrationsPerIPAddress:10.0.0.2:
|
||||||
|
burst: 40
|
||||||
|
count: 40
|
||||||
|
period: 1s
|
||||||
|
NewRegistrationsPerIPv6Range:2001:0db8:0000::/48:
|
||||||
|
burst: 50
|
||||||
|
count: 50
|
||||||
|
period: 2s
|
|
@ -0,0 +1,12 @@
|
||||||
|
CertificatesPerFQDNSetPerAccount:12345678:example.com:
|
||||||
|
burst: 40
|
||||||
|
count: 40
|
||||||
|
period: 1s
|
||||||
|
CertificatesPerFQDNSetPerAccount:12345678:example.com,example.net:
|
||||||
|
burst: 50
|
||||||
|
count: 50
|
||||||
|
period: 2s
|
||||||
|
CertificatesPerFQDNSetPerAccount:12345678:example.com,example.net,example.org:
|
||||||
|
burst: 60
|
||||||
|
count: 60
|
||||||
|
period: 3s
|
13
sa/model.go
13
sa/model.go
|
@ -12,7 +12,6 @@ import (
|
||||||
"net"
|
"net"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"golang.org/x/exp/slices"
|
"golang.org/x/exp/slices"
|
||||||
|
@ -873,14 +872,6 @@ type crlEntryModel struct {
|
||||||
RevokedDate time.Time `db:"revokedDate"`
|
RevokedDate time.Time `db:"revokedDate"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// HashNames returns a hash of the names requested. This is intended for use
|
|
||||||
// when interacting with the orderFqdnSets table.
|
|
||||||
func HashNames(names []string) []byte {
|
|
||||||
names = core.UniqueLowerNames(names)
|
|
||||||
hash := sha256.Sum256([]byte(strings.Join(names, ",")))
|
|
||||||
return hash[:]
|
|
||||||
}
|
|
||||||
|
|
||||||
// orderFQDNSet contains the SHA256 hash of the lowercased, comma joined names
|
// orderFQDNSet contains the SHA256 hash of the lowercased, comma joined names
|
||||||
// from a new-order request, along with the corresponding orderID, the
|
// from a new-order request, along with the corresponding orderID, the
|
||||||
// registration ID, and the order expiry. This is used to find
|
// registration ID, and the order expiry. This is used to find
|
||||||
|
@ -895,7 +886,7 @@ type orderFQDNSet struct {
|
||||||
|
|
||||||
func addFQDNSet(ctx context.Context, db db.Inserter, names []string, serial string, issued time.Time, expires time.Time) error {
|
func addFQDNSet(ctx context.Context, db db.Inserter, names []string, serial string, issued time.Time, expires time.Time) error {
|
||||||
return db.Insert(ctx, &core.FQDNSet{
|
return db.Insert(ctx, &core.FQDNSet{
|
||||||
SetHash: HashNames(names),
|
SetHash: core.HashNames(names),
|
||||||
Serial: serial,
|
Serial: serial,
|
||||||
Issued: issued,
|
Issued: issued,
|
||||||
Expires: expires,
|
Expires: expires,
|
||||||
|
@ -914,7 +905,7 @@ func addOrderFQDNSet(
|
||||||
regID int64,
|
regID int64,
|
||||||
expires time.Time) error {
|
expires time.Time) error {
|
||||||
return db.Insert(ctx, &orderFQDNSet{
|
return db.Insert(ctx, &orderFQDNSet{
|
||||||
SetHash: HashNames(names),
|
SetHash: core.HashNames(names),
|
||||||
OrderID: orderID,
|
OrderID: orderID,
|
||||||
RegistrationID: regID,
|
RegistrationID: regID,
|
||||||
Expires: expires,
|
Expires: expires,
|
||||||
|
|
|
@ -2893,33 +2893,6 @@ func TestBlockedKeyRevokedBy(t *testing.T) {
|
||||||
test.AssertNotError(t, err, "AddBlockedKey failed")
|
test.AssertNotError(t, err, "AddBlockedKey failed")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestHashNames(t *testing.T) {
|
|
||||||
// Test that it is deterministic
|
|
||||||
h1 := HashNames([]string{"a"})
|
|
||||||
h2 := HashNames([]string{"a"})
|
|
||||||
test.AssertByteEquals(t, h1, h2)
|
|
||||||
|
|
||||||
// Test that it differentiates
|
|
||||||
h1 = HashNames([]string{"a"})
|
|
||||||
h2 = HashNames([]string{"b"})
|
|
||||||
test.Assert(t, !bytes.Equal(h1, h2), "Should have been different")
|
|
||||||
|
|
||||||
// Test that it is not subject to ordering
|
|
||||||
h1 = HashNames([]string{"a", "b"})
|
|
||||||
h2 = HashNames([]string{"b", "a"})
|
|
||||||
test.AssertByteEquals(t, h1, h2)
|
|
||||||
|
|
||||||
// Test that it is not subject to case
|
|
||||||
h1 = HashNames([]string{"a", "b"})
|
|
||||||
h2 = HashNames([]string{"A", "B"})
|
|
||||||
test.AssertByteEquals(t, h1, h2)
|
|
||||||
|
|
||||||
// Test that it is not subject to duplication
|
|
||||||
h1 = HashNames([]string{"a", "a"})
|
|
||||||
h2 = HashNames([]string{"a"})
|
|
||||||
test.AssertByteEquals(t, h1, h2)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestIncidentsForSerial(t *testing.T) {
|
func TestIncidentsForSerial(t *testing.T) {
|
||||||
sa, _, cleanUp := initSA(t)
|
sa, _, cleanUp := initSA(t)
|
||||||
defer cleanUp()
|
defer cleanUp()
|
||||||
|
|
|
@ -516,7 +516,7 @@ func (ssa *SQLStorageAuthorityRO) CountFQDNSets(ctx context.Context, req *sapb.C
|
||||||
`SELECT COUNT(*) FROM fqdnSets
|
`SELECT COUNT(*) FROM fqdnSets
|
||||||
WHERE setHash = ?
|
WHERE setHash = ?
|
||||||
AND issued > ?`,
|
AND issued > ?`,
|
||||||
HashNames(req.Domains),
|
core.HashNames(req.Domains),
|
||||||
ssa.clk.Now().Add(-time.Duration(req.Window)),
|
ssa.clk.Now().Add(-time.Duration(req.Window)),
|
||||||
)
|
)
|
||||||
return &sapb.Count{Count: count}, err
|
return &sapb.Count{Count: count}, err
|
||||||
|
@ -544,7 +544,7 @@ func (ssa *SQLStorageAuthorityRO) FQDNSetTimestampsForWindow(ctx context.Context
|
||||||
WHERE setHash = ?
|
WHERE setHash = ?
|
||||||
AND issued > ?
|
AND issued > ?
|
||||||
ORDER BY issued DESC`,
|
ORDER BY issued DESC`,
|
||||||
HashNames(req.Domains),
|
core.HashNames(req.Domains),
|
||||||
ssa.clk.Now().Add(-time.Duration(req.Window)),
|
ssa.clk.Now().Add(-time.Duration(req.Window)),
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -586,7 +586,7 @@ type oneSelectorFunc func(ctx context.Context, holder interface{}, query string,
|
||||||
// checkFQDNSetExists uses the given oneSelectorFunc to check whether an fqdnSet
|
// checkFQDNSetExists uses the given oneSelectorFunc to check whether an fqdnSet
|
||||||
// for the given names exists.
|
// for the given names exists.
|
||||||
func (ssa *SQLStorageAuthorityRO) checkFQDNSetExists(ctx context.Context, selector oneSelectorFunc, names []string) (bool, error) {
|
func (ssa *SQLStorageAuthorityRO) checkFQDNSetExists(ctx context.Context, selector oneSelectorFunc, names []string) (bool, error) {
|
||||||
namehash := HashNames(names)
|
namehash := core.HashNames(names)
|
||||||
var exists bool
|
var exists bool
|
||||||
err := selector(
|
err := selector(
|
||||||
ctx,
|
ctx,
|
||||||
|
@ -761,7 +761,7 @@ func (ssa *SQLStorageAuthorityRO) GetOrderForNames(ctx context.Context, req *sap
|
||||||
}
|
}
|
||||||
|
|
||||||
// Hash the names requested for lookup in the orderFqdnSets table
|
// Hash the names requested for lookup in the orderFqdnSets table
|
||||||
fqdnHash := HashNames(req.Names)
|
fqdnHash := core.HashNames(req.Names)
|
||||||
|
|
||||||
// Find a possibly-suitable order. We don't include the account ID or order
|
// Find a possibly-suitable order. We don't include the account ID or order
|
||||||
// status in this query because there's no index that includes those, so
|
// status in this query because there's no index that includes those, so
|
||||||
|
|
Loading…
Reference in New Issue