Removing internal config and retry packages that were moved to dapr/kit. (#988)

Co-authored-by: Artur Souza <artursouza.ms@outlook.com>
This commit is contained in:
Phil Kedy 2021-07-02 23:17:52 -04:00 committed by GitHub
parent 5e146311b9
commit ff65172407
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 1 additions and 1181 deletions

View File

@ -10,7 +10,7 @@ import (
"strconv"
"time"
"github.com/dapr/components-contrib/internal/config"
"github.com/dapr/kit/config"
)
type Settings struct {

View File

@ -1,182 +0,0 @@
package config
import (
"fmt"
"reflect"
"strconv"
"time"
"github.com/mitchellh/mapstructure"
"github.com/pkg/errors"
)
var (
typeDuration = reflect.TypeOf(time.Duration(5)) // nolint: gochecknoglobals
typeTime = reflect.TypeOf(time.Time{}) // nolint: gochecknoglobals
typeStringDecoder = reflect.TypeOf((*StringDecoder)(nil)).Elem() // nolint: gochecknoglobals
)
// StringDecoder is used as a way for custom types (or alias types) to
// override the basic decoding function in the `decodeString`
// DecodeHook. `encoding.TextMashaller` was not used because it
// matches many Go types and would have potentially unexpected results.
// Specifying a custom decoding func should be very intentional.
type StringDecoder interface {
DecodeString(value string) error
}
// Decode decodes generic map values from `input` to `output`, while providing helpful error information.
// `output` must be a pointer to a Go struct that contains `mapstructure` struct tags on fields that should
// be decoded. This function is useful when decoding values from configuration files parsed as
// `map[string]interface{}` or component metadata as `map[string]string`.
//
// Most of the heavy lifting is handled by the mapstructure library. A custom decoder is used to handle
// decoding string values to the supported primitives.
func Decode(input interface{}, output interface{}) error {
decoder, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{ // nolint:exhaustivestruct
Result: output,
DecodeHook: decodeString,
})
if err != nil {
return err
}
return decoder.Decode(input)
}
// nolint:cyclop
func decodeString(
f reflect.Type,
t reflect.Type,
data interface{}) (interface{}, error) {
if t.Kind() == reflect.String && f.Kind() != reflect.String {
return fmt.Sprintf("%v", data), nil
}
if f.Kind() == reflect.Ptr {
f = f.Elem()
data = reflect.ValueOf(data).Elem().Interface()
}
if f.Kind() != reflect.String {
return data, nil
}
dataString, ok := data.(string)
if !ok {
return nil, errors.Errorf("expected string: got %s", reflect.TypeOf(data))
}
var result interface{}
var decoder StringDecoder
if t.Implements(typeStringDecoder) {
result = reflect.New(t.Elem()).Interface()
decoder = result.(StringDecoder)
} else if reflect.PtrTo(t).Implements(typeStringDecoder) {
result = reflect.New(t).Interface()
decoder = result.(StringDecoder)
}
if decoder != nil {
if err := decoder.DecodeString(dataString); err != nil {
if t.Kind() == reflect.Ptr {
t = t.Elem()
}
return nil, errors.Errorf("invalid %s %q: %v", t.Name(), dataString, err)
}
return result, nil
}
switch t {
case typeDuration:
// Check for simple integer value and if the value is positive treat it
// as milliseconds
if val, err := strconv.Atoi(dataString); err == nil {
return time.Duration(val) * time.Millisecond, nil
}
// Convert it by parsing
d, err := time.ParseDuration(dataString)
return d, invalidError(err, "duration", dataString)
case typeTime:
// Convert it by parsing
t, err := time.Parse(time.RFC3339Nano, dataString)
if err == nil {
return t, nil
}
t, err = time.Parse(time.RFC3339, dataString)
return t, invalidError(err, "time", dataString)
}
switch t.Kind() { // nolint: exhaustive
case reflect.Uint:
val, err := strconv.ParseUint(dataString, 10, 64)
return uint(val), invalidError(err, "uint", dataString)
case reflect.Uint64:
val, err := strconv.ParseUint(dataString, 10, 64)
return val, invalidError(err, "uint64", dataString)
case reflect.Uint32:
val, err := strconv.ParseUint(dataString, 10, 32)
return uint32(val), invalidError(err, "uint32", dataString)
case reflect.Uint16:
val, err := strconv.ParseUint(dataString, 10, 16)
return uint16(val), invalidError(err, "uint16", dataString)
case reflect.Uint8:
val, err := strconv.ParseUint(dataString, 10, 8)
return uint8(val), invalidError(err, "uint8", dataString)
case reflect.Int:
val, err := strconv.ParseInt(dataString, 10, 64)
return int(val), invalidError(err, "int", dataString)
case reflect.Int64:
val, err := strconv.ParseInt(dataString, 10, 64)
return val, invalidError(err, "int64", dataString)
case reflect.Int32:
val, err := strconv.ParseInt(dataString, 10, 32)
return int32(val), invalidError(err, "int32", dataString)
case reflect.Int16:
val, err := strconv.ParseInt(dataString, 10, 16)
return int16(val), invalidError(err, "int16", dataString)
case reflect.Int8:
val, err := strconv.ParseInt(dataString, 10, 8)
return int8(val), invalidError(err, "int8", dataString)
case reflect.Float32:
val, err := strconv.ParseFloat(dataString, 32)
return float32(val), invalidError(err, "float32", dataString)
case reflect.Float64:
val, err := strconv.ParseFloat(dataString, 64)
return val, invalidError(err, "float64", dataString)
case reflect.Bool:
val, err := strconv.ParseBool(dataString)
return val, invalidError(err, "bool", dataString)
default:
return data, nil
}
}
func invalidError(err error, msg, value string) error {
if err == nil {
return nil
}
return errors.Errorf("invalid %s %q", msg, value)
}

View File

@ -1,343 +0,0 @@
package config_test
import (
"fmt"
"strconv"
"strings"
"testing"
"time"
"github.com/agrea/ptr"
"github.com/stretchr/testify/assert"
"github.com/dapr/components-contrib/internal/config"
)
type testConfig struct { // nolint: maligned
Int int `mapstructure:"int"`
IntPtr *int `mapstructure:"intPtr"`
Int64 int64 `mapstructure:"int64"`
Int64Ptr *int64 `mapstructure:"int64Ptr"`
Int32 int32 `mapstructure:"int32"`
Int32Ptr *int32 `mapstructure:"int32Ptr"`
Int16 int16 `mapstructure:"int16"`
Int16Ptr *int16 `mapstructure:"int16Ptr"`
Int8 int8 `mapstructure:"int8"`
Int8Ptr *int8 `mapstructure:"int8Ptr"`
Uint uint `mapstructure:"uint"`
UintPtr *uint `mapstructure:"uintPtr"`
Uint64 uint64 `mapstructure:"uint64"`
Uint64Ptr *uint64 `mapstructure:"uint64Ptr"`
Uint32 uint32 `mapstructure:"uint32"`
Uint32Ptr *uint32 `mapstructure:"uint32Ptr"`
Uint16 uint16 `mapstructure:"uint16"`
Uint16Ptr *uint16 `mapstructure:"uint16Ptr"`
Byte byte `mapstructure:"byte"`
BytePtr *byte `mapstructure:"bytePtr"`
Float64 float64 `mapstructure:"float64"`
Float64Ptr *float64 `mapstructure:"float64Ptr"`
Float32 float32 `mapstructure:"float32"`
Float32Ptr *float32 `mapstructure:"float32Ptr"`
Bool bool `mapstructure:"bool"`
BoolPtr *bool `mapstructure:"boolPtr"`
Duration time.Duration `mapstructure:"duration"`
DurationPtr *time.Duration `mapstructure:"durationPtr"`
Time time.Time `mapstructure:"time"`
TimePtr *time.Time `mapstructure:"timePtr"`
String string `mapstructure:"string"`
StringPtr *string `mapstructure:"stringPtr"`
Decoded Decoded `mapstructure:"decoded"`
DecodedPtr *Decoded `mapstructure:"decodedPtr"`
Nested nested `mapstructure:"nested"`
NestedPtr *nested `mapstructure:"nestedPtr"`
}
type nested struct {
Integer int64 `mapstructure:"integer"`
String string `mapstructure:"string"`
}
type Decoded int
func (u *Decoded) DecodeString(text string) error {
if text == "unlimited" {
*u = -1
return nil
}
val, err := strconv.Atoi(text)
if err != nil {
return err
}
*u = Decoded(val)
return nil
}
func TestDecode(t *testing.T) {
timeVal := getTimeVal()
tests := map[string]interface{}{
"primitive values": map[string]interface{}{
"int": -9999,
"intPtr": ptr.Int(-9999),
"int64": -1234,
"int64Ptr": ptr.Int64(-12345),
"int32": -5678,
"int32Ptr": ptr.Int64(-5678),
"int16": -9012,
"int16Ptr": ptr.Int32(-9012),
"int8": -128,
"int8Ptr": ptr.Int8(-128),
"uint": 9999,
"uintPtr": ptr.Uint(9999),
"uint64": 1234,
"uint64Ptr": ptr.Uint64(1234),
"uint32": 5678,
"uint32Ptr": ptr.Uint64(5678),
"uint16": 9012,
"uint16Ptr": ptr.Uint64(9012),
"byte": 255,
"bytePtr": ptr.Byte(255),
"float64": 1234.5,
"float64Ptr": ptr.Float64(1234.5),
"float32": 6789.5,
"float32Ptr": ptr.Float64(6789.5),
"bool": true,
"boolPtr": ptr.Bool(true),
"duration": 5 * time.Second,
"durationPtr": durationPtr(5 * time.Second),
"time": timeVal,
"timePtr": timePtr(timeVal),
"string": 1234,
"stringPtr": ptr.String("1234"),
"decoded": "unlimited",
"decodedPtr": "unlimited",
"nested": map[string]interface{}{
"integer": 1234,
"string": 5678,
},
"nestedPtr": map[string]interface{}{
"integer": 1234,
"string": 5678,
},
},
"string values": map[string]interface{}{
"int": "-9999",
"intPtr": "-9999",
"int64": "-1234",
"int64Ptr": "-12345",
"int32": "-5678",
"int32Ptr": "-5678",
"int16": "-9012",
"int16Ptr": "-9012",
"int8": "-128",
"int8Ptr": "-128",
"uint": "9999",
"uintPtr": "9999",
"uint64": "1234",
"uint64Ptr": "1234",
"uint32": "5678",
"uint32Ptr": "5678",
"uint16": "9012",
"uint16Ptr": "9012",
"byte": "255",
"bytePtr": "255",
"float64": "1234.5",
"float64Ptr": "1234.5",
"float32": "6789.5",
"float32Ptr": "6789.5",
"bool": "true",
"boolPtr": "true",
"duration": "5000",
"durationPtr": "5s",
"time": "2021-01-02T15:04:05-07:00",
"timePtr": "2021-01-02T15:04:05-07:00",
"string": "1234",
"stringPtr": "1234",
"decoded": "unlimited",
"decodedPtr": "unlimited",
"nested": map[string]string{
"integer": "1234",
"string": "5678",
},
"nestedPtr": map[string]string{
"integer": "1234",
"string": "5678",
},
},
}
expected := getExpected()
for name, tc := range tests {
t.Run(name, func(t *testing.T) {
var actual testConfig
err := config.Decode(tc, &actual)
assert.NoError(t, err)
assert.Equal(t, expected, actual)
})
}
}
func TestDecodeErrors(t *testing.T) {
var actual testConfig
err := config.Decode(map[string]interface{}{
"int": "-badval",
"intPtr": "-badval",
"int64": "-badval",
"int64Ptr": "-badval",
"int32": "-badval",
"int32Ptr": "-badval",
"int16": "-badval",
"int16Ptr": "-badval",
"int8": "-badval",
"int8Ptr": "-badval",
"uint": "-9999",
"uintPtr": "-9999",
"uint64": "-1234",
"uint64Ptr": "-1234",
"uint32": "-5678",
"uint32Ptr": "-5678",
"uint16": "-9012",
"uint16Ptr": "-9012",
"byte": "-1",
"bytePtr": "-1",
"float64": "badval.5",
"float64Ptr": "badval.5",
"float32": "badval.5",
"float32Ptr": "badval.5",
"bool": "badval",
"boolPtr": "badval",
"duration": "badval",
"durationPtr": "badval",
"time": "badval",
"timePtr": "badval",
"decoded": "badval",
"decodedPtr": "badval",
"string": 1234,
"stringPtr": 1234,
}, &actual)
if assert.Error(t, err) {
errMsg := err.Error()
expectedNumErrors := 32
expectedPrefix := " error(s) decoding:"
assert.True(t, strings.HasPrefix(errMsg, fmt.Sprintf("%d%s", expectedNumErrors, expectedPrefix)), errMsg)
prefixIndex := strings.Index(errMsg, expectedPrefix)
if assert.True(t, prefixIndex != -1) {
errMsg = errMsg[prefixIndex+len(expectedPrefix):]
errMsg = strings.TrimSpace(errMsg)
errors := strings.Split(errMsg, "\n")
errorSet := make(map[string]struct{}, len(errors))
for _, e := range errors {
errorSet[e] = struct{}{}
}
expectedErrors := []string{
"* error decoding 'int': invalid int \"-badval\"",
"* error decoding 'intPtr': invalid int \"-badval\"",
"* error decoding 'int16': invalid int16 \"-badval\"",
"* error decoding 'int16Ptr': invalid int16 \"-badval\"",
"* error decoding 'int32': invalid int32 \"-badval\"",
"* error decoding 'int32Ptr': invalid int32 \"-badval\"",
"* error decoding 'int64': invalid int64 \"-badval\"",
"* error decoding 'int64Ptr': invalid int64 \"-badval\"",
"* error decoding 'int8': invalid int8 \"-badval\"",
"* error decoding 'int8Ptr': invalid int8 \"-badval\"",
"* error decoding 'uint': invalid uint \"-9999\"",
"* error decoding 'uintPtr': invalid uint \"-9999\"",
"* error decoding 'uint64': invalid uint64 \"-1234\"",
"* error decoding 'uint64Ptr': invalid uint64 \"-1234\"",
"* error decoding 'uint32': invalid uint32 \"-5678\"",
"* error decoding 'uint32Ptr': invalid uint32 \"-5678\"",
"* error decoding 'uint16': invalid uint16 \"-9012\"",
"* error decoding 'uint16Ptr': invalid uint16 \"-9012\"",
"* error decoding 'byte': invalid uint8 \"-1\"",
"* error decoding 'bytePtr': invalid uint8 \"-1\"",
"* error decoding 'float32': invalid float32 \"badval.5\"",
"* error decoding 'float32Ptr': invalid float32 \"badval.5\"",
"* error decoding 'float64': invalid float64 \"badval.5\"",
"* error decoding 'float64Ptr': invalid float64 \"badval.5\"",
"* error decoding 'duration': invalid duration \"badval\"",
"* error decoding 'durationPtr': invalid duration \"badval\"",
"* error decoding 'time': invalid time \"badval\"",
"* error decoding 'timePtr': invalid time \"badval\"",
"* error decoding 'decoded': invalid Decoded \"badval\": strconv.Atoi: parsing \"badval\": invalid syntax",
"* error decoding 'decodedPtr': invalid Decoded \"badval\": strconv.Atoi: parsing \"badval\": invalid syntax",
"* error decoding 'bool': invalid bool \"badval\"",
"* error decoding 'boolPtr': invalid bool \"badval\"",
}
for _, expectedError := range expectedErrors {
assert.Contains(t, errors, expectedError)
delete(errorSet, expectedError)
}
assert.Empty(t, errorSet)
}
}
}
func durationPtr(value time.Duration) *time.Duration {
return &value
}
func timePtr(value time.Time) *time.Time {
return &value
}
func decodedPtr(value Decoded) *Decoded {
return &value
}
func getTimeVal() time.Time {
timeVal, _ := time.Parse(time.RFC3339, "2021-01-02T15:04:05-07:00")
return timeVal
}
func getExpected() testConfig {
timeVal := getTimeVal()
return testConfig{
Int: -9999,
IntPtr: ptr.Int(-9999),
Int64: -1234,
Int64Ptr: ptr.Int64(-12345),
Int32: -5678,
Int32Ptr: ptr.Int32(-5678),
Int16: -9012,
Int16Ptr: ptr.Int16(-9012),
Int8: -128,
Int8Ptr: ptr.Int8(-128),
Uint: 9999,
UintPtr: ptr.Uint(9999),
Uint64: 1234,
Uint64Ptr: ptr.Uint64(1234),
Uint32: 5678,
Uint32Ptr: ptr.Uint32(5678),
Uint16: 9012,
Uint16Ptr: ptr.Uint16(9012),
Byte: 255,
BytePtr: ptr.Byte(255),
Float64: 1234.5,
Float64Ptr: ptr.Float64(1234.5),
Float32: 6789.5,
Float32Ptr: ptr.Float32(6789.5),
Bool: true,
BoolPtr: ptr.Bool(true),
Duration: 5 * time.Second,
DurationPtr: durationPtr(5 * time.Second),
Time: timeVal,
TimePtr: timePtr(timeVal),
String: "1234",
StringPtr: ptr.String("1234"),
Decoded: -1,
DecodedPtr: decodedPtr(-1),
Nested: nested{
Integer: 1234,
String: "5678",
},
NestedPtr: &nested{
Integer: 1234,
String: "5678",
},
}
}

View File

@ -1,49 +0,0 @@
// ------------------------------------------------------------
// Copyright (c) Microsoft Corporation and Dapr Contributors.
// Licensed under the MIT License.
// ------------------------------------------------------------
package config
import (
"fmt"
)
// Normalize converts map[interface{}]interface{} to map[string]interface{} to normalize
// for JSON and usage in component initialization.
// nolint:cyclop
func Normalize(i interface{}) (interface{}, error) {
var err error
switch x := i.(type) {
case map[interface{}]interface{}:
m2 := map[string]interface{}{}
for k, v := range x {
if strKey, ok := k.(string); ok {
if m2[strKey], err = Normalize(v); err != nil {
return nil, err
}
} else {
return nil, fmt.Errorf("error parsing config field: %v", k)
}
}
return m2, nil
case map[string]interface{}:
m2 := map[string]interface{}{}
for k, v := range x {
if m2[k], err = Normalize(v); err != nil {
return nil, err
}
}
return m2, nil
case []interface{}:
for i, v := range x {
if x[i], err = Normalize(v); err != nil {
return nil, err
}
}
}
return i, nil
}

View File

@ -1,98 +0,0 @@
// ------------------------------------------------------------
// Copyright (c) Microsoft Corporation and Dapr Contributors.
// Licensed under the MIT License.
// ------------------------------------------------------------
package config_test
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/dapr/components-contrib/internal/config"
)
func TestNormalize(t *testing.T) {
tests := map[string]struct {
input interface{}
expected interface{}
err string
}{
"simple": {input: "test", expected: "test"},
"map of string to interface{}": {
input: map[string]interface{}{
"test": "1234",
"nested": map[string]interface{}{
"value": "5678",
},
}, expected: map[string]interface{}{
"test": "1234",
"nested": map[string]interface{}{
"value": "5678",
},
},
},
"map of string to interface{} with error": {
input: map[string]interface{}{
"test": "1234",
"nested": map[interface{}]interface{}{
5678: "5678",
},
}, err: "error parsing config field: 5678",
},
"map of interface{} to interface{}": {
input: map[string]interface{}{
"test": "1234",
"nested": map[interface{}]interface{}{
"value": "5678",
},
}, expected: map[string]interface{}{
"test": "1234",
"nested": map[string]interface{}{
"value": "5678",
},
},
},
"map of interface{} to interface{} with error": {
input: map[interface{}]interface{}{
"test": "1234",
"nested": map[interface{}]interface{}{
5678: "5678",
},
}, err: "error parsing config field: 5678",
},
"slice of interface{}": {
input: []interface{}{
map[interface{}]interface{}{
"value": "5678",
},
}, expected: []interface{}{
map[string]interface{}{
"value": "5678",
},
},
},
"slice of interface{} with error": {
input: []interface{}{
map[interface{}]interface{}{
1234: "1234",
},
}, err: "error parsing config field: 1234",
},
}
for name, tc := range tests {
t.Run(name, func(t *testing.T) {
actual, err := config.Normalize(tc.input)
if tc.err != "" {
require.Error(t, err)
assert.EqualError(t, err, tc.err)
} else {
require.NoError(t, err)
}
assert.Equal(t, tc.expected, actual)
})
}
}

View File

@ -1,53 +0,0 @@
package config
import (
"strings"
"unicode"
)
func PrefixedBy(input interface{}, prefix string) (interface{}, error) {
normalized, err := Normalize(input)
if err != nil {
// The only error that can come from normalize is if
// input is a map[interface{}]interface{} and contains
// a key that is not a string.
return input, err
}
input = normalized
if inputMap, ok := input.(map[string]interface{}); ok {
converted := make(map[string]interface{}, len(inputMap))
for k, v := range inputMap {
if strings.HasPrefix(k, prefix) {
key := uncapitalize(strings.TrimPrefix(k, prefix))
converted[key] = v
}
}
return converted, nil
} else if inputMap, ok := input.(map[string]string); ok {
converted := make(map[string]string, len(inputMap))
for k, v := range inputMap {
if strings.HasPrefix(k, prefix) {
key := uncapitalize(strings.TrimPrefix(k, prefix))
converted[key] = v
}
}
return converted, nil
}
return input, nil
}
// uncapitalize initial capital letters in `str`.
func uncapitalize(str string) string {
if len(str) == 0 {
return str
}
vv := []rune(str) // Introduced later
vv[0] = unicode.ToLower(vv[0])
return string(vv)
}

View File

@ -1,58 +0,0 @@
package config_test
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/dapr/components-contrib/internal/config"
)
func TestPrefixedBy(t *testing.T) {
tests := map[string]struct {
prefix string
input interface{}
expected interface{}
err string
}{
"map of string to string": {
prefix: "test",
input: map[string]string{
"": "",
"ignore": "don't include me",
"testOne": "include me",
"testTwo": "and me",
},
expected: map[string]string{
"one": "include me",
"two": "and me",
},
},
"map of string to interface{}": {
prefix: "test",
input: map[string]interface{}{
"": "",
"ignore": "don't include me",
"testOne": "include me",
"testTwo": "and me",
},
expected: map[string]interface{}{
"one": "include me",
"two": "and me",
},
},
}
for name, tc := range tests {
t.Run(name, func(t *testing.T) {
actual, err := config.PrefixedBy(tc.input, tc.prefix)
if tc.err != "" {
if assert.Error(t, err) {
assert.Equal(t, tc.err, err.Error())
}
} else {
assert.Equal(t, tc.expected, actual, "unexpected output")
}
})
}
}

View File

@ -1,167 +0,0 @@
package retry
import (
"context"
"strings"
"time"
"github.com/cenkalti/backoff/v4"
"github.com/pkg/errors"
"github.com/dapr/components-contrib/internal/config"
)
// PolicyType denotes if the back off delay should be constant or exponential.
type PolicyType int
const (
// PolicyConstant is a backoff policy that always returns the same backoff delay.
PolicyConstant PolicyType = iota
// PolicyExponential is a backoff implementation that increases the backoff period
// for each retry attempt using a randomization function that grows exponentially.
PolicyExponential
)
// Config encapsulates the back off policy configuration.
type Config struct {
Policy PolicyType `mapstructure:"policy"`
// Constant back off
Duration time.Duration `mapstructure:"duration"`
// Exponential back off
InitialInterval time.Duration `mapstructure:"initialInterval"`
RandomizationFactor float32 `mapstructure:"randomizationFactor"`
Multiplier float32 `mapstructure:"multiplier"`
MaxInterval time.Duration `mapstructure:"maxInterval"`
MaxElapsedTime time.Duration `mapstructure:"maxElapsedTime"`
// Additional options
MaxRetries int64 `mapstructure:"maxRetries"`
}
// DefaultConfig represents the default configuration for a
// `Config`.
func DefaultConfig() Config {
return Config{
Policy: PolicyConstant,
Duration: 5 * time.Second,
InitialInterval: backoff.DefaultInitialInterval,
RandomizationFactor: backoff.DefaultRandomizationFactor,
Multiplier: backoff.DefaultMultiplier,
MaxInterval: backoff.DefaultMaxInterval,
MaxElapsedTime: backoff.DefaultMaxElapsedTime,
MaxRetries: -1,
}
}
// DefaultConfigWithNoRetry represents the default configuration with `MaxRetries` set to 0.
// This may be useful for those brokers which can handles retries on its own.
func DefaultConfigWithNoRetry() Config {
c := DefaultConfig()
c.MaxRetries = 0
return c
}
// DecodeConfig decodes a Go struct into a `Config`.
func DecodeConfig(c *Config, input interface{}) error {
// Use the deefault config if `c` is empty/zero value.
var emptyConfig Config
if *c == emptyConfig {
*c = DefaultConfig()
}
return config.Decode(input, c)
}
// DecodeConfigWithPrefix decodes a Go struct into a `Config`.
func DecodeConfigWithPrefix(c *Config, input interface{}, prefix string) error {
input, err := config.PrefixedBy(input, prefix)
if err != nil {
return err
}
return DecodeConfig(c, input)
}
// NewBackOff returns a BackOff instance for use with `NotifyRecover`
// or `backoff.RetryNotify` directly. The instance will not stop due to
// context cancellation. To support cancellation (recommended), use
// `NewBackOffWithContext`.
//
// Since the underlying backoff implementations are not always thread safe,
// `NewBackOff` or `NewBackOffWithContext` should be called each time
// `RetryNotifyRecover` or `backoff.RetryNotify` is used.
func (c *Config) NewBackOff() backoff.BackOff {
var b backoff.BackOff
switch c.Policy {
case PolicyConstant:
b = backoff.NewConstantBackOff(c.Duration)
case PolicyExponential:
eb := backoff.NewExponentialBackOff()
eb.InitialInterval = c.InitialInterval
eb.RandomizationFactor = float64(c.RandomizationFactor)
eb.Multiplier = float64(c.Multiplier)
eb.MaxInterval = c.MaxInterval
eb.MaxElapsedTime = c.MaxElapsedTime
b = eb
}
if c.MaxRetries >= 0 {
b = backoff.WithMaxRetries(b, uint64(c.MaxRetries))
}
return b
}
// NewBackOffWithContext returns a BackOff instance for use with `RetryNotifyRecover`
// or `backoff.RetryNotify` directly. The provided context is used to cancel retries
// if it is canceled.
//
// Since the underlying backoff implementations are not always thread safe,
// `NewBackOff` or `NewBackOffWithContext` should be called each time
// `RetryNotifyRecover` or `backoff.RetryNotify` is used.
func (c *Config) NewBackOffWithContext(ctx context.Context) backoff.BackOff {
b := c.NewBackOff()
return backoff.WithContext(b, ctx)
}
// NotifyRecover is a wrapper around backoff.RetryNotify that adds another callback for when an operation
// previously failed but has since recovered. The main purpose of this wrapper is to call `notify` only when
// the operations fails the first time and `recovered` when it finally succeeds. This can be helpful in limiting
// log messages to only the events that operators need to be alerted on.
func NotifyRecover(operation backoff.Operation, b backoff.BackOff, notify backoff.Notify, recovered func()) error {
var notified bool
return backoff.RetryNotify(func() error {
err := operation()
if err == nil && notified {
notified = false
recovered()
}
return err
}, b, func(err error, d time.Duration) {
if !notified {
notify(err, d)
notified = true
}
})
}
// DecodeString handles converting a string value to `p`.
func (p *PolicyType) DecodeString(value string) error {
switch strings.ToLower(value) {
case "constant":
*p = PolicyConstant
case "exponential":
*p = PolicyExponential
default:
return errors.Errorf("unexpected back off policy type: %s", value)
}
return nil
}

View File

@ -1,230 +0,0 @@
package retry_test
import (
"context"
"errors"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/dapr/components-contrib/internal/retry"
)
var errRetry = errors.New("Testing")
func TestDecode(t *testing.T) {
tests := map[string]struct {
config interface{}
overrides func(config *retry.Config)
err string
}{
"invalid policy type": {
config: map[string]interface{}{
"backOffPolicy": "invalid",
},
overrides: nil,
err: "1 error(s) decoding:\n\n* error decoding 'policy': invalid PolicyType \"invalid\": unexpected back off policy type: invalid",
},
"default": {
config: map[string]interface{}{},
overrides: nil,
err: "",
},
"constant default": {
config: map[string]interface{}{
"backOffPolicy": "constant",
},
overrides: nil,
err: "",
},
"constant with duraction": {
config: map[string]interface{}{
"backOffPolicy": "constant",
"backOffDuration": "10s",
},
overrides: func(config *retry.Config) {
config.Duration = 10 * time.Second
},
err: "",
},
"exponential default": {
config: map[string]interface{}{
"backOffPolicy": "exponential",
},
overrides: func(config *retry.Config) {
config.Policy = retry.PolicyExponential
},
err: "",
},
"exponential with string settings": {
config: map[string]interface{}{
"backOffPolicy": "exponential",
"backOffInitialInterval": "1000", // 1s
"backOffRandomizationFactor": "1.0",
"backOffMultiplier": "2.0",
"backOffMaxInterval": "120000", // 2m
"backOffMaxElapsedTime": "1800000", // 30m
},
overrides: func(config *retry.Config) {
config.Policy = retry.PolicyExponential
config.InitialInterval = 1 * time.Second
config.RandomizationFactor = 1.0
config.Multiplier = 2.0
config.MaxInterval = 2 * time.Minute
config.MaxElapsedTime = 30 * time.Minute
},
err: "",
},
"exponential with typed settings": {
config: map[string]interface{}{
"backOffPolicy": "exponential",
"backOffInitialInterval": "1000ms", // 1s
"backOffRandomizationFactor": 1.0,
"backOffMultiplier": 2.0,
"backOffMaxInterval": "120s", // 2m
"backOffMaxElapsedTime": "30m", // 30m
},
overrides: func(config *retry.Config) {
config.Policy = retry.PolicyExponential
config.InitialInterval = 1 * time.Second
config.RandomizationFactor = 1.0
config.Multiplier = 2.0
config.MaxInterval = 2 * time.Minute
config.MaxElapsedTime = 30 * time.Minute
},
err: "",
},
"map[string]string settings": {
config: map[string]string{
"backOffPolicy": "exponential",
"backOffInitialInterval": "1000ms", // 1s
"backOffRandomizationFactor": "1.0",
"backOffMultiplier": "2.0",
"backOffMaxInterval": "120s", // 2m
"backOffMaxElapsedTime": "30m", // 30m
},
overrides: func(config *retry.Config) {
config.Policy = retry.PolicyExponential
config.InitialInterval = 1 * time.Second
config.RandomizationFactor = 1.0
config.Multiplier = 2.0
config.MaxInterval = 2 * time.Minute
config.MaxElapsedTime = 30 * time.Minute
},
err: "",
},
}
for name, tc := range tests {
t.Run(name, func(t *testing.T) {
var actual retry.Config
err := retry.DecodeConfigWithPrefix(&actual, tc.config, "backOff")
if tc.err != "" {
if assert.Error(t, err) {
assert.Equal(t, tc.err, err.Error())
}
} else {
config := retry.DefaultConfig()
if tc.overrides != nil {
tc.overrides(&config)
}
assert.Equal(t, config, actual, "unexpected decoded configuration")
}
})
}
}
func TestRetryNotifyRecoverMaxRetries(t *testing.T) {
config := retry.DefaultConfig()
config.MaxRetries = 3
config.Duration = 1
var operationCalls, notifyCalls, recoveryCalls int
b := config.NewBackOff()
err := retry.NotifyRecover(func() error {
operationCalls++
return errRetry
}, b, func(err error, d time.Duration) {
notifyCalls++
}, func() {
recoveryCalls++
})
assert.Error(t, err)
assert.Equal(t, errRetry, err)
assert.Equal(t, 4, operationCalls)
assert.Equal(t, 1, notifyCalls)
assert.Equal(t, 0, recoveryCalls)
}
func TestRetryNotifyRecoverRecovery(t *testing.T) {
config := retry.DefaultConfig()
config.MaxRetries = 3
config.Duration = 1
var operationCalls, notifyCalls, recoveryCalls int
b := config.NewBackOff()
err := retry.NotifyRecover(func() error {
operationCalls++
if operationCalls >= 2 {
return nil
}
return errRetry
}, b, func(err error, d time.Duration) {
notifyCalls++
}, func() {
recoveryCalls++
})
assert.NoError(t, err)
assert.Equal(t, 2, operationCalls)
assert.Equal(t, 1, notifyCalls)
assert.Equal(t, 1, recoveryCalls)
}
func TestRetryNotifyRecoverCancel(t *testing.T) {
config := retry.DefaultConfig()
config.Policy = retry.PolicyConstant
config.Duration = 1 * time.Minute
var notifyCalls, recoveryCalls int
ctx, cancel := context.WithCancel(context.Background())
b := config.NewBackOffWithContext(ctx)
errC := make(chan error, 1)
startedC := make(chan struct{}, 100)
go func() {
errC <- retry.NotifyRecover(func() error {
return errRetry
}, b, func(err error, d time.Duration) {
notifyCalls++
startedC <- struct{}{}
}, func() {
recoveryCalls++
})
}()
<-startedC
cancel()
err := <-errC
assert.Error(t, err)
assert.True(t, errors.Is(err, context.Canceled))
assert.Equal(t, 1, notifyCalls)
assert.Equal(t, 0, recoveryCalls)
}
func TestCheckEmptyConfig(t *testing.T) {
var config retry.Config
err := retry.DecodeConfig(&config, map[string]interface{}{})
assert.NoError(t, err)
defaultConfig := retry.DefaultConfig()
assert.Equal(t, config, defaultConfig)
}