Misc refactorings

- Removed `state.DeleteWithOptions` and `state.SetWithOptions` which were useless at this point (likely a leftover from when there were retries)
- Some improvements in etag handling in postgres
- Other minor refactorings

Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com>
This commit is contained in:
ItalyPaleAle 2022-11-17 16:29:29 +00:00
parent d366211810
commit 1e728597b4
13 changed files with 121 additions and 215 deletions

View File

@ -130,6 +130,7 @@ func TestPublishingWithTTL(t *testing.T) {
const maxGetDuration = ttlInSeconds * time.Second
metadata := bindings.Metadata{
Base: contribMetadata.Base{
Name: "testQueue",
Properties: map[string]string{
"queueName": queueName,
@ -137,6 +138,7 @@ func TestPublishingWithTTL(t *testing.T) {
"deleteWhenUnused": strconv.FormatBool(exclusive),
"durable": strconv.FormatBool(durable),
},
},
}
logger := logger.NewLogger("test")
@ -162,7 +164,7 @@ func TestPublishingWithTTL(t *testing.T) {
},
}
_, err = rabbitMQBinding1.Invoke(context.Backgound(), &writeRequest)
_, err = rabbitMQBinding1.Invoke(context.Background(), &writeRequest)
assert.Nil(t, err)
time.Sleep(time.Second + (ttlInSeconds * time.Second))
@ -183,7 +185,7 @@ func TestPublishingWithTTL(t *testing.T) {
contribMetadata.TTLMetadataKey: strconv.Itoa(ttlInSeconds * 1000),
},
}
_, err = rabbitMQBinding2.Invoke(context.Backgound(), &writeRequest)
_, err = rabbitMQBinding2.Invoke(context.Background(), &writeRequest)
assert.Nil(t, err)
msg, ok, err := getMessageWithRetries(ch, queueName, maxGetDuration)
@ -204,6 +206,7 @@ func TestExclusiveQueue(t *testing.T) {
const maxGetDuration = ttlInSeconds * time.Second
metadata := bindings.Metadata{
Base: contribMetadata.Base{
Name: "testQueue",
Properties: map[string]string{
"queueName": queueName,
@ -213,6 +216,7 @@ func TestExclusiveQueue(t *testing.T) {
"exclusive": strconv.FormatBool(exclusive),
contribMetadata.TTLMetadataKey: strconv.FormatInt(ttlInSeconds, 10),
},
},
}
logger := logger.NewLogger("test")
@ -257,6 +261,7 @@ func TestPublishWithPriority(t *testing.T) {
const maxPriority = 10
metadata := bindings.Metadata{
Base: contribMetadata.Base{
Name: "testQueue",
Properties: map[string]string{
"queueName": queueName,
@ -265,6 +270,7 @@ func TestPublishWithPriority(t *testing.T) {
"durable": strconv.FormatBool(durable),
"maxPriority": strconv.FormatInt(maxPriority, 10),
},
},
}
logger := logger.NewLogger("test")
@ -283,7 +289,7 @@ func TestPublishWithPriority(t *testing.T) {
defer ch.Close()
const middlePriorityMsgContent = "middle"
_, err = r.Invoke(context.Backgound(), &bindings.InvokeRequest{
_, err = r.Invoke(context.Background(), &bindings.InvokeRequest{
Metadata: map[string]string{
contribMetadata.PriorityMetadataKey: "5",
},
@ -292,7 +298,7 @@ func TestPublishWithPriority(t *testing.T) {
assert.Nil(t, err)
const lowPriorityMsgContent = "low"
_, err = r.Invoke(context.Backgound(), &bindings.InvokeRequest{
_, err = r.Invoke(context.Background(), &bindings.InvokeRequest{
Metadata: map[string]string{
contribMetadata.PriorityMetadataKey: "1",
},
@ -301,7 +307,7 @@ func TestPublishWithPriority(t *testing.T) {
assert.Nil(t, err)
const highPriorityMsgContent = "high"
_, err = r.Invoke(context.Backgound(), &bindings.InvokeRequest{
_, err = r.Invoke(context.Background(), &bindings.InvokeRequest{
Metadata: map[string]string{
contribMetadata.PriorityMetadataKey: "10",
},

View File

@ -52,11 +52,11 @@ func createIotHubPubsubMetadata() pubsub.Metadata {
metadata := pubsub.Metadata{
Base: metadata.Base{
Properties: map[string]string{
connectionString: os.Getenv(iotHubConnectionStringEnvKey),
consumerID: os.Getenv(iotHubConsumerGroupEnvKey),
storageAccountName: os.Getenv(storageAccountNameEnvKey),
storageAccountKey: os.Getenv(storageAccountKeyEnvKey),
storageContainerName: testStorageContainerName,
"connectionString": os.Getenv(iotHubConnectionStringEnvKey),
"consumerID": os.Getenv(iotHubConsumerGroupEnvKey),
"storageAccountName": os.Getenv(storageAccountNameEnvKey),
"storageAccountKey": os.Getenv(storageAccountKeyEnvKey),
"storageContainerName": testStorageContainerName,
},
},
}

View File

@ -114,11 +114,6 @@ func (p *cockroachDBAccess) Init(metadata state.Metadata) error {
// Set makes an insert or update to the database.
func (p *cockroachDBAccess) Set(req *state.SetRequest) error {
return state.SetWithOptions(p.setValue, req)
}
// setValue is an internal implementation of set to enable passing the logic to state.SetWithRetries as a func.
func (p *cockroachDBAccess) setValue(req *state.SetRequest) error {
p.logger.Debug("Setting state value in CockroachDB")
value, isBinary, err := validateAndReturnValue(req)
@ -240,11 +235,6 @@ func (p *cockroachDBAccess) Get(req *state.GetRequest) (*state.GetResponse, erro
// Delete removes an item from the state store.
func (p *cockroachDBAccess) Delete(req *state.DeleteRequest) error {
return state.DeleteWithOptions(p.deleteValue, req)
}
// deleteValue is an internal implementation of delete to enable passing the logic to state.DeleteWithRetries as a func.
func (p *cockroachDBAccess) deleteValue(req *state.DeleteRequest) error {
p.logger.Debug("Deleting state value from CockroachDB")
if req.Key == "" {
return fmt.Errorf("missing key in delete operation")

View File

@ -113,7 +113,8 @@ func (f *Firestore) Get(req *state.GetRequest) (*state.GetResponse, error) {
}, nil
}
func (f *Firestore) setValue(req *state.SetRequest) error {
// Set saves state into Firestore.
func (f *Firestore) Set(req *state.SetRequest) error {
err := state.CheckRequestOptions(req.Options)
if err != nil {
return err
@ -142,12 +143,8 @@ func (f *Firestore) setValue(req *state.SetRequest) error {
return nil
}
// Set saves state into Firestore with retry.
func (f *Firestore) Set(req *state.SetRequest) error {
return state.SetWithOptions(f.setValue, req)
}
func (f *Firestore) deleteValue(req *state.DeleteRequest) error {
// Delete performs a delete operation.
func (f *Firestore) Delete(req *state.DeleteRequest) error {
ctx := context.Background()
key := datastore.NameKey(f.entityKind, req.Key, nil)
@ -159,11 +156,6 @@ func (f *Firestore) deleteValue(req *state.DeleteRequest) error {
return nil
}
// Delete performs a delete operation.
func (f *Firestore) Delete(req *state.DeleteRequest) error {
return state.DeleteWithOptions(f.deleteValue, req)
}
func getFirestoreMetadata(meta state.Metadata) (*firestoreMetadata, error) {
m := firestoreMetadata{
EntityKind: defaultEntityKind,

View File

@ -146,7 +146,7 @@ func (m *Memcached) parseTTL(req *state.SetRequest) (*int32, error) {
return nil, nil
}
func (m *Memcached) setValue(req *state.SetRequest) error {
func (m *Memcached) Set(req *state.SetRequest) error {
var bt []byte
ttl, err := m.parseTTL(req)
if err != nil {
@ -194,10 +194,6 @@ func (m *Memcached) Get(req *state.GetRequest) (*state.GetResponse, error) {
}, nil
}
func (m *Memcached) Set(req *state.SetRequest) error {
return state.SetWithOptions(m.setValue, req)
}
func (m *Memcached) GetComponentMetadata() map[string]string {
metadataStruct := memcachedMetadata{}
metadataInfo := map[string]string{}

View File

@ -30,7 +30,6 @@ import (
"github.com/dapr/components-contrib/metadata"
"github.com/dapr/components-contrib/state"
"github.com/dapr/kit/logger"
"github.com/dapr/kit/ptr"
)
// Optimistic Concurrency is implemented using a string column that stores a UUID.
@ -337,8 +336,6 @@ func (m *MySQL) Delete(req *state.DeleteRequest) error {
return m.deleteValue(m.db, req)
}
// deleteValue is an internal implementation of delete to enable passing the
// logic to state.DeleteWithRetries as a func.
func (m *MySQL) deleteValue(querier querier, req *state.DeleteRequest) error {
m.logger.Debug("Deleting state value from MySql")
@ -458,14 +455,14 @@ func (m *MySQL) Get(req *state.GetRequest) (*state.GetResponse, error) {
return &state.GetResponse{
Data: data,
ETag: ptr.Of(eTag),
ETag: &eTag,
Metadata: req.Metadata,
}, nil
}
return &state.GetResponse{
Data: value,
ETag: ptr.Of(eTag),
ETag: &eTag,
Metadata: req.Metadata,
}, nil
}
@ -476,8 +473,6 @@ func (m *MySQL) Set(req *state.SetRequest) error {
return m.setValue(m.db, req)
}
// setValue is an internal implementation of set to enable passing the logic
// to state.SetWithRetries as a func.
func (m *MySQL) setValue(querier querier, req *state.SetRequest) error {
m.logger.Debug("Setting state value in MySql")

View File

@ -114,11 +114,6 @@ func (o *oracleDatabaseAccess) Init(metadata state.Metadata) error {
return nil
}
// Set makes an insert or update to the database.
func (o *oracleDatabaseAccess) Set(req *state.SetRequest) error {
return state.SetWithOptions(o.setValue, req)
}
func parseTTL(requestMetadata map[string]string) (*int, error) {
if val, found := requestMetadata[metadataTTLKey]; found && val != "" {
parsedVal, err := strconv.ParseInt(val, 10, 0)
@ -133,8 +128,8 @@ func parseTTL(requestMetadata map[string]string) (*int, error) {
return nil, nil
}
// setValue is an internal implementation of set to enable passing the logic to state.SetWithRetries as a func.
func (o *oracleDatabaseAccess) setValue(req *state.SetRequest) error {
// Set makes an insert or update to the database.
func (o *oracleDatabaseAccess) Set(req *state.SetRequest) error {
o.logger.Debug("Setting state value in OracleDatabase")
err := state.CheckRequestOptions(req.Options)
if err != nil {
@ -204,6 +199,7 @@ func (o *oracleDatabaseAccess) setValue(req *state.SetRequest) error {
result, err = tx.Exec(mergeStatement, req.Key, value, binaryYN, etag, ttlSeconds)
} else {
// when first write policy is indicated, an existing record has to be updated - one that has the etag provided.
// TODO: Needs to update ttl_in_seconds
updateStatement := fmt.Sprintf(
`UPDATE %s SET value = :value, binary_yn = :binary_yn, etag = :new_etag
WHERE key = :key AND etag = :etag`,
@ -273,11 +269,6 @@ func (o *oracleDatabaseAccess) Get(req *state.GetRequest) (*state.GetResponse, e
// Delete removes an item from the state store.
func (o *oracleDatabaseAccess) Delete(req *state.DeleteRequest) error {
return state.DeleteWithOptions(o.deleteValue, req)
}
// deleteValue is an internal implementation of delete to enable passing the logic to state.DeleteWithRetries as a func.
func (o *oracleDatabaseAccess) deleteValue(req *state.DeleteRequest) error {
o.logger.Debug("Deleting state value from OracleDatabase")
if req.Key == "" {
return fmt.Errorf("missing key in delete operation")
@ -354,7 +345,7 @@ func (o *oracleDatabaseAccess) ExecuteMulti(sets []state.SetRequest, deletes []s
return err
}
// Close implements io.Close.
// Close implements io.Closer.
func (o *oracleDatabaseAccess) Close() error {
if o.db != nil {
return o.db.Close()
@ -391,10 +382,3 @@ func tableExists(db *sql.DB, tableName string) (bool, error) {
exists := tblCount > 0
return exists, err
}
// func handleError(msg string, err error) {
// if err != nil {
// fmt.Println(msg, err)
// }
// }

View File

@ -17,6 +17,7 @@ import (
"database/sql"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"strconv"
"time"
@ -77,7 +78,7 @@ func (p *postgresDBAccess) Init(meta state.Metadata) error {
if m.ConnectionString == "" {
p.logger.Error("Missing postgreSQL connection string")
return fmt.Errorf(errMissingConnectionString)
return errors.New(errMissingConnectionString)
}
p.connectionString = m.ConnectionString
@ -111,11 +112,6 @@ func (p *postgresDBAccess) Init(meta state.Metadata) error {
// Set makes an insert or update to the database.
func (p *postgresDBAccess) Set(req *state.SetRequest) error {
return state.SetWithOptions(p.setValue, req)
}
// setValue is an internal implementation of set to enable passing the logic to state.SetWithRetries as a func.
func (p *postgresDBAccess) setValue(req *state.SetRequest) error {
p.logger.Debug("Setting state value in PostgreSQL")
err := state.CheckRequestOptions(req.Options)
@ -124,11 +120,11 @@ func (p *postgresDBAccess) setValue(req *state.SetRequest) error {
}
if req.Key == "" {
return fmt.Errorf("missing key in set operation")
return errors.New("missing key in set operation")
}
if v, ok := req.Value.(string); ok && v == "" {
return fmt.Errorf("empty string is not allowed in set operation")
return errors.New("empty string is not allowed in set operation")
}
v := req.Value
@ -149,7 +145,7 @@ func (p *postgresDBAccess) setValue(req *state.SetRequest) error {
result, err = p.db.Exec(fmt.Sprintf(
`INSERT INTO %s (key, value, isbinary) VALUES ($1, $2, $3);`,
p.tableName), req.Key, value, isBinary)
} else if req.ETag == nil {
} else if req.ETag == nil || *req.ETag == "" {
result, err = p.db.Exec(fmt.Sprintf(
`INSERT INTO %s (key, value, isbinary) VALUES ($1, $2, $3)
ON CONFLICT (key) DO UPDATE SET value = $2, isbinary = $3, updatedate = NOW();`,
@ -184,7 +180,7 @@ func (p *postgresDBAccess) setValue(req *state.SetRequest) error {
}
if rows != 1 {
return fmt.Errorf("no item was updated")
return errors.New("no item was updated")
}
return nil
@ -218,27 +214,30 @@ func (p *postgresDBAccess) BulkSet(req []state.SetRequest) error {
func (p *postgresDBAccess) Get(req *state.GetRequest) (*state.GetResponse, error) {
p.logger.Debug("Getting state value from PostgreSQL")
if req.Key == "" {
return nil, fmt.Errorf("missing key in get operation")
return nil, errors.New("missing key in get operation")
}
var value string
var isBinary bool
var etag int
var (
value []byte
isBinary bool
etag uint64 // Postgres uses uint32, but FormatUint requires uint64, so using uint64 directly to avoid re-allocations
)
err := p.db.QueryRow(fmt.Sprintf("SELECT value, isbinary, xmin as etag FROM %s WHERE key = $1", p.tableName), req.Key).Scan(&value, &isBinary, &etag)
if err != nil {
// If no rows exist, return an empty response, otherwise return the error.
if err == sql.ErrNoRows {
return &state.GetResponse{}, nil
}
return nil, err
}
if isBinary {
var s string
var data []byte
var (
s string
data []byte
)
if err = json.Unmarshal([]byte(value), &s); err != nil {
if err = json.Unmarshal(value, &s); err != nil {
return nil, err
}
@ -248,34 +247,28 @@ func (p *postgresDBAccess) Get(req *state.GetRequest) (*state.GetResponse, error
return &state.GetResponse{
Data: data,
ETag: ptr.Of(strconv.Itoa(etag)),
ETag: ptr.Of(strconv.FormatUint(etag, 10)),
Metadata: req.Metadata,
}, nil
}
return &state.GetResponse{
Data: []byte(value),
ETag: ptr.Of(strconv.Itoa(etag)),
Data: value,
ETag: ptr.Of(strconv.FormatUint(etag, 10)),
Metadata: req.Metadata,
}, nil
}
// Delete removes an item from the state store.
func (p *postgresDBAccess) Delete(req *state.DeleteRequest) error {
return state.DeleteWithOptions(p.deleteValue, req)
}
// deleteValue is an internal implementation of delete to enable passing the logic to state.DeleteWithRetries as a func.
func (p *postgresDBAccess) deleteValue(req *state.DeleteRequest) error {
func (p *postgresDBAccess) Delete(req *state.DeleteRequest) (err error) {
p.logger.Debug("Deleting state value from PostgreSQL")
if req.Key == "" {
return fmt.Errorf("missing key in delete operation")
return errors.New("missing key in delete operation")
}
var result sql.Result
var err error
if req.ETag == nil {
if req.ETag == nil || *req.ETag == "" {
result, err = p.db.Exec("DELETE FROM state WHERE key = $1", req.Key)
} else {
// Convert req.ETag to uint32 for postgres XID compatibility
@ -313,12 +306,10 @@ func (p *postgresDBAccess) BulkDelete(req []state.DeleteRequest) error {
}
if len(req) > 0 {
for _, d := range req {
da := d // Fix for gosec G601: Implicit memory aliasing in for loop.
err = p.Delete(&da)
for i := range req {
err = p.Delete(&req[i])
if err != nil {
tx.Rollback()
return err
}
}
@ -446,11 +437,11 @@ func tableExists(db *sql.DB, tableName string) (bool, error) {
func getSet(req state.TransactionalStateOperation) (state.SetRequest, error) {
setReq, ok := req.Request.(state.SetRequest)
if !ok {
return setReq, fmt.Errorf("expecting set request")
return setReq, errors.New("expecting set request")
}
if setReq.Key == "" {
return setReq, fmt.Errorf("missing key in upsert operation")
return setReq, errors.New("missing key in upsert operation")
}
return setReq, nil
@ -460,11 +451,11 @@ func getSet(req state.TransactionalStateOperation) (state.SetRequest, error) {
func getDelete(req state.TransactionalStateOperation) (state.DeleteRequest, error) {
delReq, ok := req.Request.(state.DeleteRequest)
if !ok {
return delReq, fmt.Errorf("expecting delete request")
return delReq, errors.New("expecting delete request")
}
if delReq.Key == "" {
return delReq, fmt.Errorf("missing key in upsert operation")
return delReq, errors.New("missing key in upsert operation")
}
return delReq, nil

View File

@ -88,9 +88,9 @@ func (q *Query) visitFilters(op string, filters []query.Filter) (string, error)
}
}
sep := fmt.Sprintf(" %s ", op)
sep := " " + op + " "
return fmt.Sprintf("(%s)", strings.Join(arr, sep)), nil
return "(" + strings.Join(arr, sep) + ")", nil
}
func (q *Query) VisitAND(f *query.AND) (string, error) {
@ -102,10 +102,10 @@ func (q *Query) VisitOR(f *query.OR) (string, error) {
}
func (q *Query) Finalize(filters string, qq *query.Query) error {
q.query = fmt.Sprintf("SELECT key, value, xmin as etag FROM %s", q.tableName)
q.query = "SELECT key, value, xmin as etag FROM " + q.tableName
if filters != "" {
q.query += fmt.Sprintf(" WHERE %s", filters)
q.query += " WHERE " + filters
}
if len(qq.Sort) > 0 {
@ -117,13 +117,13 @@ func (q *Query) Finalize(filters string, qq *query.Query) error {
}
q.query += translateFieldToFilter(sortItem.Key)
if sortItem.Order != "" {
q.query += fmt.Sprintf(" %s", sortItem.Order)
q.query += " " + sortItem.Order
}
}
}
if qq.Page.Limit > 0 {
q.query += fmt.Sprintf(" LIMIT %d", qq.Page.Limit)
q.query += " LIMIT " + strconv.Itoa(qq.Page.Limit)
q.limit = qq.Page.Limit
}
@ -132,7 +132,7 @@ func (q *Query) Finalize(filters string, qq *query.Query) error {
if err != nil {
return err
}
q.query += fmt.Sprintf(" OFFSET %d", skip)
q.query += " OFFSET " + strconv.FormatInt(skip, 10)
q.skip = &skip
}
@ -151,7 +151,7 @@ func (q *Query) execute(logger logger.Logger, db *sql.DB) ([]state.QueryItem, st
var (
key string
data []byte
etag int
etag uint64 // Postgres uses uint32, but FormatUint requires uint64, so using uint64 directly to avoid re-allocations
)
if err = rows.Scan(&key, &data, &etag); err != nil {
return nil, "", err
@ -159,7 +159,7 @@ func (q *Query) execute(logger logger.Logger, db *sql.DB) ([]state.QueryItem, st
result := state.QueryItem{
Key: key,
Data: data,
ETag: ptr.Of(strconv.Itoa(etag)),
ETag: ptr.Of(strconv.FormatUint(etag, 10)),
}
ret = append(ret, result)
}
@ -200,7 +200,7 @@ func translateFieldToFilter(key string) string {
filterField += ">"
}
filterField += fmt.Sprintf("'%s'", fieldPart)
filterField += "'" + fieldPart + "'"
}
return filterField
@ -209,6 +209,6 @@ func translateFieldToFilter(key string) string {
func (q *Query) whereFieldEqual(key string, value interface{}) string {
position := q.addParamValueAndReturnPosition(value)
filterField := translateFieldToFilter(key)
query := fmt.Sprintf("%s=$%v", filterField, position)
query := filterField + "=$" + strconv.Itoa(position)
return query
}

View File

@ -196,7 +196,13 @@ func (r *StateStore) parseConnectedSlaves(res string) int {
return 0
}
func (r *StateStore) deleteValue(req *state.DeleteRequest) error {
// Delete performs a delete operation.
func (r *StateStore) Delete(req *state.DeleteRequest) error {
err := state.CheckRequestOptions(req.Options)
if err != nil {
return err
}
if req.ETag == nil {
etag := "0"
req.ETag = &etag
@ -208,7 +214,7 @@ func (r *StateStore) deleteValue(req *state.DeleteRequest) error {
} else {
delQuery = delDefaultQuery
}
_, err := r.client.Do(r.ctx, "EVAL", delQuery, 1, req.Key, *req.ETag).Result()
_, err = r.client.Do(r.ctx, "EVAL", delQuery, 1, req.Key, *req.ETag).Result()
if err != nil {
return state.NewETagError(state.ETagMismatch, err)
}
@ -216,16 +222,6 @@ func (r *StateStore) deleteValue(req *state.DeleteRequest) error {
return nil
}
// Delete performs a delete operation.
func (r *StateStore) Delete(req *state.DeleteRequest) error {
err := state.CheckRequestOptions(req.Options)
if err != nil {
return err
}
return state.DeleteWithOptions(r.deleteValue, req)
}
func (r *StateStore) directGet(req *state.GetRequest) (*state.GetResponse, error) {
res, err := r.client.Do(r.ctx, "GET", req.Key).Result()
if err != nil {
@ -318,7 +314,8 @@ type jsonEntry struct {
Version *int `json:"version,omitempty"`
}
func (r *StateStore) setValue(req *state.SetRequest) error {
// Set saves state into redis.
func (r *StateStore) Set(req *state.SetRequest) error {
err := state.CheckRequestOptions(req.Options)
if err != nil {
return err
@ -384,11 +381,6 @@ func (r *StateStore) setValue(req *state.SetRequest) error {
return nil
}
// Set saves state into redis.
func (r *StateStore) Set(req *state.SetRequest) error {
return state.SetWithOptions(r.setValue, req)
}
// Multi performs a transactional operation. succeeds only if all operations succeed, and fails if one or more operations fail.
func (r *StateStore) Multi(request *state.TransactionalStateRequest) error {
var setQuery, delQuery string

View File

@ -66,13 +66,3 @@ func validateConsistencyOption(c string) error {
return nil
}
// SetWithOptions handles SetRequest with request options.
func SetWithOptions(method func(req *SetRequest) error, req *SetRequest) error {
return method(req)
}
// DeleteWithOptions handles DeleteRequest with options.
func DeleteWithOptions(method func(req *DeleteRequest) error, req *DeleteRequest) error {
return method(req)
}

View File

@ -19,31 +19,6 @@ import (
"github.com/stretchr/testify/assert"
)
// TestSetRequestWithOptions is used to test request options.
func TestSetRequestWithOptions(t *testing.T) {
t.Run("set with default options", func(t *testing.T) {
counter := 0
SetWithOptions(func(req *SetRequest) error {
counter++
return nil
}, &SetRequest{})
assert.Equal(t, 1, counter, "should execute only once")
})
t.Run("set with no explicit options", func(t *testing.T) {
counter := 0
SetWithOptions(func(req *SetRequest) error {
counter++
return nil
}, &SetRequest{
Options: SetStateOption{},
})
assert.Equal(t, 1, counter, "should execute only once")
})
}
// TestCheckRequestOptions is used to validate request options.
func TestCheckRequestOptions(t *testing.T) {
t.Run("set state options", func(t *testing.T) {

View File

@ -187,8 +187,7 @@ func (s *StateStore) Delete(req *state.DeleteRequest) error {
return err
}
return state.DeleteWithOptions(func(req *state.DeleteRequest) error {
err := s.conn.Delete(r.Path, r.Version)
err = s.conn.Delete(r.Path, r.Version)
if errors.Is(err, zk.ErrNoNode) {
return nil
}
@ -202,7 +201,6 @@ func (s *StateStore) Delete(req *state.DeleteRequest) error {
}
return nil
}, req)
}
// BulkDelete performs a bulk delete operation.
@ -239,9 +237,7 @@ func (s *StateStore) Set(req *state.SetRequest) error {
return err
}
return state.SetWithOptions(func(req *state.SetRequest) error {
_, err = s.conn.Set(r.Path, r.Data, r.Version)
if errors.Is(err, zk.ErrNoNode) {
_, err = s.conn.Create(r.Path, r.Data, 0, nil)
}
@ -255,7 +251,6 @@ func (s *StateStore) Set(req *state.SetRequest) error {
}
return nil
}, req)
}
// BulkSet performs a bulks save operation.