package ratelimits

import (
	"context"
	"errors"
	"net"
	"time"

	"github.com/jmhodges/clock"
	"github.com/prometheus/client_golang/prometheus"
	"github.com/redis/go-redis/v9"
)

// Compile-time check that RedisSource implements the source interface.
var _ Source = (*RedisSource)(nil)

// RedisSource is a ratelimits source backed by sharded Redis.
type RedisSource struct {
	client  *redis.Ring
	clk     clock.Clock
	latency *prometheus.HistogramVec
}

// NewRedisSource returns a new Redis backed source using the provided
// *redis.Ring client.
func NewRedisSource(client *redis.Ring, clk clock.Clock, stats prometheus.Registerer) *RedisSource {
	latency := prometheus.NewHistogramVec(
		prometheus.HistogramOpts{
			Name: "ratelimits_latency",
Help: "Histogram of Redis call latencies labeled by call=[set|get|delete|ping] and result=[success|error]",
			// Exponential buckets ranging from 0.0005s to 3s.
			Buckets: prometheus.ExponentialBucketsRange(0.0005, 3, 8),
		},
		[]string{"call", "result"},
	)
	stats.MustRegister(latency)

	return &RedisSource{
		client:  client,
		clk:     clk,
		latency: latency,
	}
}
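
// Example wiring for NewRedisSource (an illustrative sketch, not part of this
// package's API; the shard names, addresses, and use of a fresh registry are
// assumptions chosen for the example):
//
//	ring := redis.NewRing(&redis.RingOptions{
//		Addrs: map[string]string{
//			"shard1": "10.0.0.1:6379", // hypothetical shard address
//			"shard2": "10.0.0.2:6379", // hypothetical shard address
//		},
//	})
//	src := NewRedisSource(ring, clock.New(), prometheus.NewRegistry())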

var errMixedSuccess = errors.New("some keys not found")

// resultForError returns a string representing the result of the operation
// based on the provided error.
func resultForError(err error) string {
	if errors.Is(err, errMixedSuccess) {
		// Indicates that some of the keys in a batch operation were not found.
		return "mixedSuccess"
	} else if errors.Is(err, redis.Nil) {
		// Bucket key does not exist.
		return "notFound"
	} else if errors.Is(err, context.DeadlineExceeded) {
		// Client read or write deadline exceeded.
		return "deadlineExceeded"
	} else if errors.Is(err, context.Canceled) {
		// Caller canceled the operation.
		return "canceled"
	}
	var netErr net.Error
	if errors.As(err, &netErr) && netErr.Timeout() {
		// Dialer timed out connecting to Redis.
		return "timeout"
	}
	var redisErr redis.Error
	if errors.As(err, &redisErr) {
		// An internal error was returned by the Redis server.
		return "redisError"
	}
	return "failed"
}

// observeLatency records a latency observation for the provided call,
// labeling it with a result derived from err ("success" when err is nil).
func (r *RedisSource) observeLatency(call string, latency time.Duration, err error) {
	result := "success"
	if err != nil {
		result = resultForError(err)
	}
	r.latency.With(prometheus.Labels{"call": call, "result": result}).Observe(latency.Seconds())
}

// BatchSet stores TATs at the specified bucketKeys using a pipelined Redis
// Transaction in order to reduce the number of round-trips to each Redis shard.
func (r *RedisSource) BatchSet(ctx context.Context, buckets map[string]time.Time) error {
	start := r.clk.Now()

	pipeline := r.client.Pipeline()
	for bucketKey, tat := range buckets {
		// Set a TTL of TAT + 10 minutes to account for clock skew.
		ttl := tat.UTC().Sub(r.clk.Now()) + 10*time.Minute
		pipeline.Set(ctx, bucketKey, tat.UTC().UnixNano(), ttl)
	}
	_, err := pipeline.Exec(ctx)
	if err != nil {
		r.observeLatency("batchset", r.clk.Since(start), err)
		return err
	}

	totalLatency := r.clk.Since(start)
	perSetLatency := totalLatency / time.Duration(len(buckets))
	for range buckets {
		r.observeLatency("batchset_entry", perSetLatency, nil)
	}

	r.observeLatency("batchset", totalLatency, nil)
	return nil
}
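
// Illustrative use of BatchSet (a sketch; src, ctx, clk, the bucket key, and
// the TAT value are assumptions chosen for the example). TATs are stored as
// UnixNano integers with a TTL extending 10 minutes past the TAT:
//
//	err := src.BatchSet(ctx, map[string]time.Time{
//		"newOrdersPerAccount:12345": clk.Now().Add(30 * time.Second),
//	})
//	if err != nil {
//		// The whole pipelined write failed; no per-entry results are available.
//	}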

// BatchSetNotExisting attempts to set TATs for the specified bucketKeys if they
// do not already exist. Returns a map indicating which keys already existed.
func (r *RedisSource) BatchSetNotExisting(ctx context.Context, buckets map[string]time.Time) (map[string]bool, error) {
	start := r.clk.Now()

	pipeline := r.client.Pipeline()
	cmds := make(map[string]*redis.BoolCmd, len(buckets))
	for bucketKey, tat := range buckets {
		// Set a TTL of TAT + 10 minutes to account for clock skew.
		ttl := tat.UTC().Sub(r.clk.Now()) + 10*time.Minute
		cmds[bucketKey] = pipeline.SetNX(ctx, bucketKey, tat.UTC().UnixNano(), ttl)
	}
	_, err := pipeline.Exec(ctx)
	if err != nil {
		r.observeLatency("batchsetnotexisting", r.clk.Since(start), err)
		return nil, err
	}

	alreadyExists := make(map[string]bool, len(buckets))
	totalLatency := r.clk.Since(start)
	perSetLatency := totalLatency / time.Duration(len(buckets))
	for bucketKey, cmd := range cmds {
		success, err := cmd.Result()
		if err != nil {
			r.observeLatency("batchsetnotexisting_entry", perSetLatency, err)
			return nil, err
		}
		if !success {
			alreadyExists[bucketKey] = true
		}
		r.observeLatency("batchsetnotexisting_entry", perSetLatency, nil)
	}

	r.observeLatency("batchsetnotexisting", totalLatency, nil)
	return alreadyExists, nil
}
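
// Illustrative use of BatchSetNotExisting (a sketch; src, ctx, clk, and the
// bucket key are assumptions). Keys reported in the returned map were left
// untouched because they already held a TAT:
//
//	existing, err := src.BatchSetNotExisting(ctx, map[string]time.Time{
//		"certificatesPerDomain:example.com": clk.Now(),
//	})
//	if err == nil && existing["certificatesPerDomain:example.com"] {
//		// The bucket was already initialized; its stored TAT was preserved.
//	}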

// BatchIncrement updates TATs for the specified bucketKeys using a pipelined
// Redis Transaction in order to reduce the number of round-trips to each Redis
// shard.
func (r *RedisSource) BatchIncrement(ctx context.Context, buckets map[string]increment) error {
	start := r.clk.Now()

	pipeline := r.client.Pipeline()
	for bucketKey, incr := range buckets {
		pipeline.IncrBy(ctx, bucketKey, incr.cost.Nanoseconds())
		pipeline.Expire(ctx, bucketKey, incr.ttl)
	}
	_, err := pipeline.Exec(ctx)
	if err != nil {
		r.observeLatency("batchincrby", r.clk.Since(start), err)
		return err
	}

	totalLatency := r.clk.Since(start)
	perSetLatency := totalLatency / time.Duration(len(buckets))
	for range buckets {
		r.observeLatency("batchincrby_entry", perSetLatency, nil)
	}

	r.observeLatency("batchincrby", totalLatency, nil)
	return nil
}
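
// Illustrative use of BatchIncrement (a sketch; src, ctx, the bucket key, and
// the cost and ttl values are assumptions). Each entry advances the stored
// nanosecond TAT by incr.cost and refreshes the key's TTL to incr.ttl:
//
//	err := src.BatchIncrement(ctx, map[string]increment{
//		"newOrdersPerAccount:12345": {cost: 2 * time.Second, ttl: 10 * time.Minute},
//	})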

// Get retrieves the TAT at the specified bucketKey. If the bucketKey does not
// exist, ErrBucketNotFound is returned.
func (r *RedisSource) Get(ctx context.Context, bucketKey string) (time.Time, error) {
	start := r.clk.Now()

	tatNano, err := r.client.Get(ctx, bucketKey).Int64()
	if err != nil {
		if errors.Is(err, redis.Nil) {
			// Bucket key does not exist.
			r.observeLatency("get", r.clk.Since(start), err)
			return time.Time{}, ErrBucketNotFound
		}
		// An error occurred while retrieving the TAT.
		r.observeLatency("get", r.clk.Since(start), err)
		return time.Time{}, err
	}

	r.observeLatency("get", r.clk.Since(start), nil)
	return time.Unix(0, tatNano).UTC(), nil
}
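
// Illustrative use of Get (a sketch; src, ctx, and the bucket key are
// assumptions):
//
//	tat, err := src.Get(ctx, "newOrdersPerAccount:12345")
//	if errors.Is(err, ErrBucketNotFound) {
//		// No TAT has been stored for this bucket yet.
//	}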

// BatchGet retrieves the TATs at the specified bucketKeys using a pipelined
// Redis Transaction in order to reduce the number of round-trips to each Redis
// shard. If a bucketKey does not exist, it WILL NOT be included in the returned
// map.
func (r *RedisSource) BatchGet(ctx context.Context, bucketKeys []string) (map[string]time.Time, error) {
	start := r.clk.Now()

	pipeline := r.client.Pipeline()
	for _, bucketKey := range bucketKeys {
		pipeline.Get(ctx, bucketKey)
	}
	results, err := pipeline.Exec(ctx)
	if err != nil && !errors.Is(err, redis.Nil) {
		r.observeLatency("batchget", r.clk.Since(start), err)
		return nil, err
	}

	totalLatency := r.clk.Since(start)
	perEntryLatency := totalLatency / time.Duration(len(bucketKeys))

	tats := make(map[string]time.Time, len(bucketKeys))
	notFoundCount := 0
	for i, result := range results {
		tatNano, err := result.(*redis.StringCmd).Int64()
		if err != nil {
			if !errors.Is(err, redis.Nil) {
				// This should never happen as any errors should have been
				// caught after the pipeline.Exec() call.
				r.observeLatency("batchget", r.clk.Since(start), err)
				return nil, err
			}
			// Bucket key does not exist.
			r.observeLatency("batchget_entry", perEntryLatency, err)
			notFoundCount++
			continue
		}
		tats[bucketKeys[i]] = time.Unix(0, tatNano).UTC()
		r.observeLatency("batchget_entry", perEntryLatency, nil)
	}

	var batchErr error
	if notFoundCount > 0 && notFoundCount < len(results) {
		// Some keys were not found.
		batchErr = errMixedSuccess
	} else if notFoundCount == len(results) {
		// All keys were not found.
		batchErr = redis.Nil
	}

	r.observeLatency("batchget", totalLatency, batchErr)
	return tats, nil
}
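
// Illustrative use of BatchGet (a sketch; src, ctx, and the bucket keys are
// assumptions). Keys with no stored TAT are simply absent from the returned
// map:
//
//	tats, err := src.BatchGet(ctx, []string{
//		"newOrdersPerAccount:12345",
//		"newOrdersPerAccount:67890",
//	})
//	if err == nil {
//		if _, ok := tats["newOrdersPerAccount:67890"]; !ok {
//			// No TAT stored for this bucket yet.
//		}
//	}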

// Delete deletes the TAT at the specified bucketKey ('name:id'). A nil return
// value does not indicate that the bucketKey existed.
func (r *RedisSource) Delete(ctx context.Context, bucketKey string) error {
	start := r.clk.Now()

	err := r.client.Del(ctx, bucketKey).Err()
	if err != nil {
		r.observeLatency("delete", r.clk.Since(start), err)
		return err
	}

	r.observeLatency("delete", r.clk.Since(start), nil)
	return nil
}

// Ping checks that each shard of the *redis.Ring is reachable using the PING
// command.
func (r *RedisSource) Ping(ctx context.Context) error {
	start := r.clk.Now()

	err := r.client.ForEachShard(ctx, func(ctx context.Context, shard *redis.Client) error {
		return shard.Ping(ctx).Err()
	})
	if err != nil {
		r.observeLatency("ping", r.clk.Since(start), err)
		return err
	}

	r.observeLatency("ping", r.clk.Since(start), nil)
	return nil
}