boulder/ratelimits/source_redis.go

180 lines
6.0 KiB
Go

package ratelimits
import (
"context"
"errors"
"net"
"time"
"github.com/jmhodges/clock"
"github.com/prometheus/client_golang/prometheus"
"github.com/redis/go-redis/v9"
)
// Compile-time check that RedisSource implements the source interface. The
// typed nil pointer is assigned to the blank identifier, so the declaration
// exists only for the assignability check.
var _ source = (*RedisSource)(nil)
// RedisSource is a ratelimits source backed by sharded Redis. Every method
// issues its commands through client and reports the call latency to the
// latency histogram.
type RedisSource struct {
// client is the *redis.Ring used for all Redis commands (GET, SET, DEL and
// per-shard PING).
client *redis.Ring
// clk supplies the start timestamp for each call's latency measurement.
clk clock.Clock
// latency is the histogram of call latencies, labeled by "call" and
// "result".
latency *prometheus.HistogramVec
}
// NewRedisSource builds a RedisSource on top of the provided *redis.Ring
// client. It creates the "ratelimits_latency" histogram, labeled by "call" and
// "result", and registers it with stats via MustRegister before returning the
// new source.
func NewRedisSource(client *redis.Ring, clk clock.Clock, stats prometheus.Registerer) *RedisSource {
	opts := prometheus.HistogramOpts{
		Name: "ratelimits_latency",
		Help: "Histogram of Redis call latencies labeled by call=[set|get|delete|ping] and result=[success|error]",
		// Eight exponentially spaced buckets spanning 0.0005s to 3s.
		Buckets: prometheus.ExponentialBucketsRange(0.0005, 3, 8),
	}
	labelNames := []string{"call", "result"}

	src := &RedisSource{
		client:  client,
		clk:     clk,
		latency: prometheus.NewHistogramVec(opts, labelNames),
	}
	stats.MustRegister(src.latency)
	return src
}
// resultForError maps err to the value used for the "result" metric label.
// The checks are ordered from most to least specific and the first match wins:
//
//   - "notFound": err is (or wraps) redis.Nil, i.e. the key does not exist.
//   - "deadlineExceeded": err is (or wraps) context.DeadlineExceeded.
//   - "canceled": err is (or wraps) context.Canceled.
//   - "timeout": err is (or wraps) a net.Error whose Timeout() is true.
//   - "redisError": err is (or wraps) a redis.Error.
//   - "failed": anything else.
//
// context.DeadlineExceeded must be tested before the net.Error case because it
// also satisfies net.Error and reports Timeout() == true.
func resultForError(err error) string {
	var netErr net.Error
	var redisErr redis.Error

	switch {
	case errors.Is(err, redis.Nil):
		// Bucket key does not exist.
		return "notFound"
	case errors.Is(err, context.DeadlineExceeded):
		// Client read or write deadline exceeded.
		return "deadlineExceeded"
	case errors.Is(err, context.Canceled):
		// Caller canceled the operation.
		return "canceled"
	case errors.As(err, &netErr) && netErr.Timeout():
		// Dialer timed out connecting to Redis.
		return "timeout"
	case errors.As(err, &redisErr):
		// An internal error was returned by the Redis server. errors.As is
		// required here: redis.Error is an interface type, so there is no
		// sentinel value for errors.Is to compare against.
		return "redisError"
	default:
		return "failed"
	}
}
// BatchSet stores the TAT of every entry in buckets (bucketKey -> TAT) using a
// single Redis pipeline in order to reduce the number of round-trips to each
// Redis shard. Each TAT is written as its UTC Unix time in nanoseconds, with
// an expiration argument of 0. The commands are queued with Pipeline(), not
// TxPipeline(). An error is returned if the operation failed and nil
// otherwise.
//
// The latency is recorded under call="batchset" and is measured with the
// injected clock, the same clock that produced the start timestamp.
func (r *RedisSource) BatchSet(ctx context.Context, buckets map[string]time.Time) error {
	start := r.clk.Now()

	pipeline := r.client.Pipeline()
	for bucketKey, tat := range buckets {
		pipeline.Set(ctx, bucketKey, tat.UTC().UnixNano(), 0)
	}

	_, err := pipeline.Exec(ctx)
	if err != nil {
		r.latency.With(prometheus.Labels{"call": "batchset", "result": resultForError(err)}).Observe(r.clk.Since(start).Seconds())
		return err
	}

	r.latency.With(prometheus.Labels{"call": "batchset", "result": "success"}).Observe(r.clk.Since(start).Seconds())
	return nil
}
// Get retrieves the TAT at the specified bucketKey. The stored value is read
// as an int64 count of nanoseconds since the Unix epoch and returned as a UTC
// time.Time. If the bucketKey does not exist, ErrBucketNotFound is returned.
// Any other failure is returned as-is together with the zero time.Time.
//
// The latency is recorded under call="get" with result "success", "notFound",
// or the value produced by resultForError. It is measured with the injected
// clock, the same clock that produced the start timestamp.
func (r *RedisSource) Get(ctx context.Context, bucketKey string) (time.Time, error) {
	start := r.clk.Now()

	tatNano, err := r.client.Get(ctx, bucketKey).Int64()
	if err != nil {
		if errors.Is(err, redis.Nil) {
			// Bucket key does not exist.
			r.latency.With(prometheus.Labels{"call": "get", "result": "notFound"}).Observe(r.clk.Since(start).Seconds())
			return time.Time{}, ErrBucketNotFound
		}
		r.latency.With(prometheus.Labels{"call": "get", "result": resultForError(err)}).Observe(r.clk.Since(start).Seconds())
		return time.Time{}, err
	}

	r.latency.With(prometheus.Labels{"call": "get", "result": "success"}).Observe(r.clk.Since(start).Seconds())
	return time.Unix(0, tatNano).UTC(), nil
}
// BatchGet retrieves the TATs at the specified bucketKeys using a single Redis
// pipeline in order to reduce the number of round-trips to each Redis shard.
// The commands are queued with Pipeline(), not TxPipeline(). An error is
// returned if the operation failed and nil otherwise. If a bucketKey does not
// exist, it WILL NOT be included in the returned map.
//
// Exactly one latency observation is recorded per call under call="batchget":
// "success" when a map is returned, otherwise the value produced by
// resultForError. It is measured with the injected clock, the same clock that
// produced the start timestamp.
func (r *RedisSource) BatchGet(ctx context.Context, bucketKeys []string) (map[string]time.Time, error) {
	start := r.clk.Now()

	pipeline := r.client.Pipeline()
	for _, bucketKey := range bucketKeys {
		pipeline.Get(ctx, bucketKey)
	}

	results, err := pipeline.Exec(ctx)
	// Exec may report redis.Nil when a key is missing. That is not a failure
	// of the batch: missing keys are skipped individually in the loop below.
	if err != nil && !errors.Is(err, redis.Nil) {
		r.latency.With(prometheus.Labels{"call": "batchget", "result": resultForError(err)}).Observe(r.clk.Since(start).Seconds())
		return nil, err
	}

	tats := make(map[string]time.Time, len(bucketKeys))
	for i, result := range results {
		// Only GET commands were queued above, so each result is a
		// *redis.StringCmd, in the same order as bucketKeys.
		tatNano, err := result.(*redis.StringCmd).Int64()
		if err != nil {
			if errors.Is(err, redis.Nil) {
				// Bucket key does not exist.
				continue
			}
			r.latency.With(prometheus.Labels{"call": "batchget", "result": resultForError(err)}).Observe(r.clk.Since(start).Seconds())
			return nil, err
		}
		tats[bucketKeys[i]] = time.Unix(0, tatNano).UTC()
	}

	r.latency.With(prometheus.Labels{"call": "batchget", "result": "success"}).Observe(r.clk.Since(start).Seconds())
	return tats, nil
}
// Delete deletes the TAT at the specified bucketKey ('name:id'). It returns an
// error if the operation failed and nil otherwise. A nil return value does not
// indicate that the bucketKey existed.
//
// The latency is recorded under call="delete" and is measured with the
// injected clock, the same clock that produced the start timestamp.
func (r *RedisSource) Delete(ctx context.Context, bucketKey string) error {
	start := r.clk.Now()

	err := r.client.Del(ctx, bucketKey).Err()
	if err != nil {
		r.latency.With(prometheus.Labels{"call": "delete", "result": resultForError(err)}).Observe(r.clk.Since(start).Seconds())
		return err
	}

	r.latency.With(prometheus.Labels{"call": "delete", "result": "success"}).Observe(r.clk.Since(start).Seconds())
	return nil
}
// Ping checks that each shard of the *redis.Ring is reachable by issuing the
// PING command to it through ForEachShard. It returns the error reported by
// ForEachShard if the check failed and nil otherwise.
//
// The latency is recorded under call="ping" and is measured with the injected
// clock, the same clock that produced the start timestamp.
func (r *RedisSource) Ping(ctx context.Context) error {
	start := r.clk.Now()

	err := r.client.ForEachShard(ctx, func(ctx context.Context, shard *redis.Client) error {
		return shard.Ping(ctx).Err()
	})
	if err != nil {
		r.latency.With(prometheus.Labels{"call": "ping", "result": resultForError(err)}).Observe(r.clk.Since(start).Seconds())
		return err
	}

	r.latency.With(prometheus.Labels{"call": "ping", "result": "success"}).Observe(r.clk.Since(start).Seconds())
	return nil
}