components-contrib/common/component/redis/redis.go

402 lines
14 KiB
Go

/*
Copyright 2021 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package redis
import (
"context"
"errors"
"fmt"
"strconv"
"strings"
"time"
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"github.com/cenkalti/backoff/v4"
"github.com/lestrrat-go/jwx/v2/jwt"
"golang.org/x/mod/semver"
"github.com/dapr/components-contrib/common/authentication/azure"
"github.com/dapr/components-contrib/configuration"
"github.com/dapr/components-contrib/metadata"
kitlogger "github.com/dapr/kit/logger"
kitretry "github.com/dapr/kit/retry"
)
// Redis deployment type values.
// NOTE(review): presumably compared against a deployment-type field decoded
// into Settings elsewhere in the package — confirm.
const (
	ClusterType = "cluster"
	NodeType    = "node"
)

// Metadata keys that ParseClientFromProperties reads directly from the
// component properties before/after decoding them into Settings.
const (
	// Pub/sub keys whose plain-integer values are interpreted as milliseconds.
	processingTimeoutKey = "processingTimeout"
	redeliverIntervalKey = "redeliverInterval"

	// Current state/lock store retry keys.
	redisMinRetryIntervalKey = "redisMinRetryInterval"
	redisMaxRetriesKey       = "redisMaxRetries"

	// Legacy state/lock store keys. maxRetryBackoff is parsed as an integer
	// number of nanoseconds and used as the minimum retry interval;
	// maxRetries is copied to redisMaxRetries.
	maxRetryBackoffKey = "maxRetryBackoff"
	maxRetriesKey      = "maxRetries"
)
// RedisXMessage is a single stream entry: its ID and its field/value pairs.
// It is returned inside RedisXStream by RedisClient.XReadGroupResult and
// directly by RedisClient.XClaimResult.
type RedisXMessage struct {
	ID     string
	Values map[string]interface{}
}
// RedisXStream groups the messages read from one stream, named by Stream, as
// returned by RedisClient.XReadGroupResult.
type RedisXStream struct {
	Stream   string
	Messages []RedisXMessage
}
// RedisXPendingExt describes one pending stream entry as returned by
// RedisClient.XPendingExtResult: the entry ID, the consumer name, how long it
// has been idle, and its retry count.
// NOTE(review): field meanings presumably follow the extended XPENDING reply
// (RetryCount being the delivery counter) — confirm against the v8/v9 clients.
type RedisXPendingExt struct {
	ID         string
	Consumer   string
	Idle       time.Duration
	RetryCount int64
}
// RedisPipeliner is the transaction pipeline returned by RedisClient.TxPipeline.
// NOTE(review): presumably Do queues a command (given as args) and Exec sends
// all queued commands — confirm against the v8/v9 implementations.
type RedisPipeliner interface {
	Exec(ctx context.Context) error
	Do(ctx context.Context, args ...interface{})
}
// RedisClient is the client abstraction used by the Redis components.
// ParseClientFromProperties returns an implementation built by either the v8
// or the v9 client constructors, chosen from the server version.
//
// NOTE(review): method notes below either describe how this file uses a method
// or are hedged guesses from the names; confirm exact semantics against the
// v8/v9 implementations.
//
//nolint:interfacebloat
type RedisClient interface {
	// GetNilValueError returns the error value that presumably represents a
	// nil reply (e.g. a missing key).
	GetNilValueError() RedisError
	Context() context.Context
	// DoRead sends an arbitrary command given as args and returns the raw
	// reply; GetServerVersion and GetConnectedSlaves use it for INFO.
	DoRead(ctx context.Context, args ...interface{}) (interface{}, error)
	// DoWrite sends an arbitrary command and returns only its error;
	// ClientHasJSONSupport uses it to probe JSON.GET.
	DoWrite(ctx context.Context, args ...interface{}) error
	Del(ctx context.Context, keys ...string) error
	Get(ctx context.Context, key string) (string, error)
	GetDel(ctx context.Context, key string) (string, error)
	Close() error
	PingResult(ctx context.Context) (string, error)
	ConfigurationSubscribe(ctx context.Context, args *ConfigurationSubscribeArgs)
	SetNX(ctx context.Context, key string, value interface{}, expiration time.Duration) (*bool, error)
	// EvalInt presumably runs a Lua script and returns an integer result.
	// NOTE(review): the meaning of the two error results is not visible here.
	EvalInt(ctx context.Context, script string, keys []string, args ...interface{}) (*int, error, error)
	XAdd(ctx context.Context, stream string, maxLenApprox int64, streamTTL string, values map[string]interface{}) (string, error)
	XGroupCreateMkStream(ctx context.Context, stream string, group string, start string) error
	XAck(ctx context.Context, stream string, group string, messageID string) error
	XReadGroupResult(ctx context.Context, group string, consumer string, streams []string, count int64, block time.Duration) ([]RedisXStream, error)
	XPendingExtResult(ctx context.Context, stream string, group string, start string, end string, count int64) ([]RedisXPendingExt, error)
	XClaimResult(ctx context.Context, stream string, group string, consumer string, minIdleTime time.Duration, messageIDs []string) ([]RedisXMessage, error)
	TxPipeline() RedisPipeliner
	TTLResult(ctx context.Context, key string) (time.Duration, error)
	// AuthACL authenticates with username and password; the Entra ID refresh
	// routine calls it with a freshly acquired token as the password.
	AuthACL(ctx context.Context, username, password string) error
}
// ConfigurationSubscribeArgs carries the parameters of
// RedisClient.ConfigurationSubscribe.
// NOTE(review): presumably the client subscribes to RedisChannel and invokes
// HandleSubscribedChange with Req, Handler, the channel name and ID for each
// change; IsAllKeysChannel appears to mark a subscription covering all keys —
// confirm against the v8/v9 implementations.
type ConfigurationSubscribeArgs struct {
	HandleSubscribedChange func(ctx context.Context, req *configuration.SubscribeRequest, handler configuration.UpdateHandler, channel string, id string)
	Req                    *configuration.SubscribeRequest
	Handler                configuration.UpdateHandler
	RedisChannel           string
	IsAllKeysChannel       bool
	ID                     string
}
// ParseClientFromProperties decodes the component metadata in properties into
// Settings and returns a RedisClient built from them, together with those
// settings.
//
// Before decoding it applies per-component-type defaults. For state and lock
// stores it also maps legacy keys: "maxRetryBackoff" (an integer number of
// nanoseconds) becomes the minimum retry interval, and "maxRetries" is copied
// to "redisMaxRetries". That copy is written into properties, so the caller's
// map is modified. For pub/sub, processingTimeout and redeliverInterval values
// given as plain integers are interpreted as milliseconds.
//
// The client library is chosen by probing the server with a temporary v8
// client. The v9 client is used when the server reports version 7.0.0 or
// newer, or when the version cannot be determined. When UseEntraID is set, ctx
// is used to fetch the initial Entra ID token (used as the password), and a
// background routine is started to refresh it.
func ParseClientFromProperties(properties map[string]string, componentType metadata.ComponentType, ctx context.Context, logger *kitlogger.Logger) (RedisClient, *Settings, error) {
	settings := Settings{}

	// Upgrade legacy metadata properties and set defaults; settings.Decode
	// below is expected to override them with any values present in properties.
	switch componentType {
	case metadata.ConfigurationStoreType:
		// Apply legacy defaults
		settings.RedisMaxRetries = 3
		settings.RedisMaxRetryInterval = Duration(2 * time.Second)
		settings.RedisMinRetryInterval = Duration(8 * time.Millisecond)
	case metadata.StateStoreType, metadata.LockStoreType:
		// Apply legacy defaults: the same retry window as the configuration
		// store. 2s is the upper bound of the back-off, not its lower bound.
		settings.RedisMaxRetries = 3
		settings.RedisMaxRetryInterval = Duration(2 * time.Second)
		settings.RedisMinRetryInterval = Duration(8 * time.Millisecond)

		// Parse legacy keys
		if properties[redisMinRetryIntervalKey] == "" && properties[maxRetryBackoffKey] != "" {
			// due to different duration formats, do not simply change the key name
			parsedVal, err := strconv.ParseInt(properties[maxRetryBackoffKey], 10, 0)
			if err != nil {
				return nil, nil, fmt.Errorf("redis store error: can't parse maxRetryBackoff field: %w", err)
			}
			settings.RedisMinRetryInterval = Duration(time.Duration(parsedVal))
		}
		if properties[redisMaxRetriesKey] == "" && properties[maxRetriesKey] != "" {
			properties[redisMaxRetriesKey] = properties[maxRetriesKey]
		}
	case metadata.PubSubType:
		settings.ProcessingTimeout = 60 * time.Second
		settings.RedeliverInterval = 15 * time.Second
		settings.QueueDepth = 100
		settings.Concurrency = 10
	}

	err := settings.Decode(properties)
	if err != nil {
		return nil, nil, fmt.Errorf("redis client configuration error: %w", err)
	}

	if componentType == metadata.PubSubType {
		if val, ok := properties[processingTimeoutKey]; ok && val != "" {
			if processingTimeoutMs, parseErr := strconv.ParseUint(val, 10, 64); parseErr == nil {
				// because of legacy reasons, we need to interpret a number as milliseconds
				// the library would default to seconds otherwise
				settings.ProcessingTimeout = time.Duration(processingTimeoutMs) * time.Millisecond //nolint:gosec
			}
			// if there was an error we would try to interpret it as a duration string, which was already done in Decode()
		}
		if val, ok := properties[redeliverIntervalKey]; ok && val != "" {
			if redeliverIntervalMs, parseErr := strconv.ParseUint(val, 10, 64); parseErr == nil {
				// because of legacy reasons, we need to interpret a number as milliseconds
				// the library would default to seconds otherwise
				settings.RedeliverInterval = time.Duration(redeliverIntervalMs) * time.Millisecond //nolint:gosec
			}
			// if there was an error we would try to interpret it as a duration string, which was already done in Decode()
		}
	}

	var tokenExpires *time.Time
	var tokenCredential *azcore.TokenCredential
	if settings.UseEntraID {
		tokenExpires, tokenCredential, err = settings.GetEntraIDCredentialAndSetInitialTokenAsPassword(ctx, &properties)
		if err != nil {
			return nil, nil, err
		}
	}

	// Probe the server with a temporary v8 client to choose the client library.
	var c RedisClient
	newClientFunc := newV8Client
	if settings.Failover {
		newClientFunc = newV8FailoverClient
	}
	c, err = newClientFunc(&settings)
	if err != nil {
		return nil, nil, fmt.Errorf("redis client configuration error: %w", err)
	}
	version, versionErr := GetServerVersion(c)
	// close the probe client to avoid leaking connections
	if closeErr := c.Close(); closeErr != nil {
		return nil, nil, fmt.Errorf("redis client configuration error: %w", closeErr)
	}

	useNewClient := false
	if versionErr != nil {
		// we couldn't query the server version, so we will assume the v8 client is not supported
		useNewClient = true
	} else if semver.Compare("v"+version, "v7.0.0") > -1 {
		// if the server version is >= 7, we will use the v9 client
		useNewClient = true
	}
	if useNewClient {
		newClientFunc = newV9Client
		if settings.Failover {
			newClientFunc = newV9FailoverClient
		}
	}
	c, err = newClientFunc(&settings)
	if err != nil {
		return nil, nil, fmt.Errorf("redis client configuration error: %w", err)
	}

	// start the token refresh goroutine
	if settings.UseEntraID {
		StartEntraIDTokenRefreshBackgroundRoutine(c, settings.Username, *tokenExpires, tokenCredential, logger)
	}
	return c, &settings, nil
}
// StartEntraIDTokenRefreshBackgroundRoutine starts a goroutine that keeps the
// Entra ID token used as the Redis password from expiring.
//
// The goroutine waits until 5 minutes before nextExpiration. It then fetches a
// new access token for the "https://redis.azure.com/.default" scope from cred
// and re-authenticates client via AuthACL with username and the new token.
// The next refresh is scheduled 5 minutes before the new token's expiry.
//
// Token acquisition and AUTH are each retried with an exponential back-off
// (at most 3 retries). If either step still fails, client is closed and the
// failure is reported through logger's Fatalf.
//
// NOTE(review): the goroutine's context comes from context.Background() and
// nothing outside cancels it, so the loop outlives client.Close() and the
// ctx.Done branch is effectively unreachable. If an expiry is less than
// 5 minutes away, the wait is non-positive and the refresh runs immediately.
func StartEntraIDTokenRefreshBackgroundRoutine(client RedisClient, username string, nextExpiration time.Time, cred *azcore.TokenCredential, logger *kitlogger.Logger) {
	go func(cred *azcore.TokenCredential, username string, logger *kitlogger.Logger) {
		const refreshGracePeriod = 5 * time.Minute

		ctx, cancel := context.WithCancel(context.Background())
		defer cancel()

		retryCfg := kitretry.DefaultConfig()
		retryCfg.MaxRetries = 3
		retryCfg.Policy = kitretry.PolicyExponential

		// newBackOff gives each step its own, freshly reset retry budget.
		newBackOff := func() backoff.BackOff {
			return retryCfg.NewBackOffWithContext(ctx)
		}

		// acquireToken requests a new Entra ID access token, retrying on failure.
		acquireToken := func() (azcore.AccessToken, error) {
			var token azcore.AccessToken
			err := kitretry.NotifyRecover(
				func() (opErr error) {
					token, opErr = (*cred).GetToken(ctx, policy.TokenRequestOptions{
						Scopes: []string{"https://redis.azure.com/.default"},
					})
					return opErr
				},
				newBackOff(),
				func(err error, _ time.Duration) {
					(*logger).Debugf("redis client: entraID token acquisition failed with error: %v. Retrying...", err)
				},
				func() {
					(*logger).Debug("redis client: entraID token acquisition succeeded after error")
				},
			)
			return token, err
		}

		// authenticate sends AUTH with the given token as password, retrying on failure.
		authenticate := func(secret string) error {
			return kitretry.NotifyRecover(
				func() error {
					return client.AuthACL(ctx, username, secret)
				},
				newBackOff(),
				func(err error, _ time.Duration) {
					(*logger).Debugf("redis client: entraID auth failed with error: %v. Retrying...", err)
				},
				func() {
					(*logger).Debug("redis client: entraID auth succeeded after error")
				},
			)
		}

		wait := time.Until(nextExpiration.Add(-refreshGracePeriod))
		(*logger).Debugf("redis client: starting entraID token refresh loop")
		for {
			(*logger).Debugf("redis client: next entraID token refresh: %v", wait)
			select {
			case <-ctx.Done():
				(*logger).Infof("redis client: entraID token refresh stopped due to context cancellation")
				return
			case <-time.After(wait):
			}

			(*logger).Debug("redis client: refreshing entraID token")
			token, err := acquireToken()
			if err != nil {
				_ = client.Close()
				(*logger).Fatalf("redis client: entraID token acquisition failed: %v", err)
				return
			}
			if err = authenticate(token.Token); err != nil {
				_ = client.Close()
				(*logger).Fatalf("redis client: entraID auth failed: %v", err)
				return
			}

			// Both steps succeeded: schedule the next refresh from the new expiry.
			(*logger).Debugf("redis client: entraID auth token successfully refreshed with the server")
			wait = time.Until(token.ExpiresOn.Add(-refreshGracePeriod))
		}
	}(cred, username, logger)
}
// GetEntraIDCredentialAndSetInitialTokenAsPassword prepares Entra ID
// authentication for the Redis client.
//
// It builds an Azure token credential from the Azure environment settings in
// properties and fetches an access token for the
// "https://redis.azure.com/.default" scope. On success it stores the token in
// s.Password and the token's "oid" (object ID) claim in s.Username. s is left
// unchanged when an error is returned.
//
// It returns the token's expiry and the credential, so the caller can refresh
// the token later. An error is returned in these cases:
//   - s already has a Username or Password;
//   - the credential or token cannot be obtained;
//   - the token cannot be parsed;
//   - the token carries no string "oid" claim.
func (s *Settings) GetEntraIDCredentialAndSetInitialTokenAsPassword(ctx context.Context, properties *map[string]string) (*time.Time, *azcore.TokenCredential, error) {
	if len(s.Password) > 0 || len(s.Username) > 0 {
		return nil, nil, errors.New(
			"redis client configuration error: username or password must not be specified when using Entra ID authentication")
	}

	envSettings, err := azure.NewEnvironmentSettings(*properties)
	if err != nil {
		return nil, nil, fmt.Errorf("redis client configuration error: %w", err)
	}
	cred, err := envSettings.GetTokenCredential()
	if err != nil {
		return nil, nil, fmt.Errorf("redis client configuration error: %w", err)
	}
	token, err := cred.GetToken(ctx, policy.TokenRequestOptions{
		Scopes: []string{"https://redis.azure.com/.default"},
	})
	if err != nil {
		return nil, nil, fmt.Errorf("redis client configuration error: %w", err)
	}

	// This token has already been validated by EntraID. We use insecure parsing to get the object ID.
	parsedToken, err := jwt.ParseString(token.Token, jwt.WithVerify(false), jwt.WithValidate(false))
	if err != nil {
		return nil, nil, fmt.Errorf("redis client configuration error: %w", err)
	}
	// Check the claim's type rather than asserting it: a missing or non-string
	// "oid" must surface as a configuration error, not a panic.
	claim, found := parsedToken.Get("oid")
	objectID, isString := claim.(string)
	if !found || !isString {
		return nil, nil, errors.New("redis client configuration error: could not parse object ID from Auth token")
	}

	s.Username = objectID
	s.Password = token.Token
	return &token.ExpiresOn, &cred, nil
}
// ClientHasJSONSupport reports whether the server accepts the JSON.GET command.
//
// It sends JSON.GET without arguments, with a 5 second timeout. It returns
// false only when the resulting error text starts with "ERR unknown command".
// A successful reply, or any other error (including connection errors),
// counts as support.
func ClientHasJSONSupport(c RedisClient) bool {
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	err := c.DoWrite(ctx, "JSON.GET")
	unknownCommand := err != nil && strings.HasPrefix(err.Error(), "ERR unknown command")
	return !unknownCommand
}
func GetServerVersion(c RedisClient) (string, error) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
res, err := c.DoRead(ctx, "INFO", "server")
if err != nil {
return "", err
}
// get row in string res beginning with "redis_version"
rows := strings.Split(res.(string), "\n")
for _, row := range rows {
if strings.HasPrefix(row, "redis_version:") {
return strings.TrimSpace(strings.Split(row, ":")[1]), nil
}
}
return "", errors.New("could not find redis_version in redis info response")
}
// GetConnectedSlaves returns the number of slaves connected to the Redis master.
//
// It reads the "connected_slaves" field of the INFO replication reply. It
// returns 0 when the reply is empty, is neither a string nor a []byte, or has
// no connected_slaves field. It returns an error when the query fails or when
// the field's value is not an integer.
func GetConnectedSlaves(ctx context.Context, c RedisClient) (int, error) {
	const connectedSlavesReplicas = "connected_slaves:"

	res, err := c.DoRead(ctx, "INFO", "replication")
	if err != nil {
		return 0, err
	}

	// Response example: https://redis.io/commands/info#return-value
	// # Replication\r\nrole:master\r\nconnected_slaves:1\r\n
	var info string
	switch v := res.(type) {
	case string:
		info = v
	case []byte:
		info = string(v)
	default:
		return 0, nil
	}

	// Split on '\n' and trim the '\r' so both CRLF and bare LF replies work.
	for _, line := range strings.Split(info, "\n") {
		line = strings.TrimSpace(line)
		if !strings.HasPrefix(line, connectedSlavesReplicas) {
			continue
		}
		replicas, err := strconv.Atoi(strings.TrimPrefix(line, connectedSlavesReplicas))
		if err != nil {
			return 0, fmt.Errorf("parsing %q from redis info response: %w", line, err)
		}
		return replicas, nil
	}
	return 0, nil
}
// RedisError is an error whose message is exactly its string value. It is the
// type returned by RedisClient.GetNilValueError.
type RedisError string

// Error implements the error interface by returning the underlying string.
func (err RedisError) Error() string {
	return string(err)
}

// RedisError is an empty marker method.
// NOTE(review): presumably lets the type satisfy an interface that identifies
// Redis reply errors in the underlying client library — confirm.
func (RedisError) RedisError() {}